Transformers disrupting Vision as well — ViT

Anirban Sen
Mar 10, 2024
Photo by David Travis on Unsplash

As of 2024, Transformers rule the world, with LLMs (which are Transformer decoders in most cases) everywhere you look. We have already covered the overall Transformer architecture and BERT (Transformer encoders) in earlier blogs. In this blog we will learn about the Vision Transformer (ViT), which brings the idea of attention to vision tasks as well, a domain long dominated by CNNs.

What is ViT?
It was introduced by Google back in 2021 in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (ICLR 2021). The name is quite interesting, right? There is an analogy behind it, which we will see shortly. The goal was to reduce the reliance on CNNs: the paper shows that a pure Transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data, it attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train.
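To make the title's analogy concrete, here is a small back-of-the-envelope helper that counts how many "words" an image turns into. The function names are hypothetical (not from the paper or any library), and it assumes a square image whose side is divisible by the patch size.

```python
def vit_sequence_length(image_size: int, patch_size: int, with_class_token: bool = True) -> int:
    """Number of tokens a ViT sees for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2  # N = H * W / P^2
    return num_patches + (1 if with_class_token else 0)


def patch_dim(patch_size: int, channels: int = 3) -> int:
    """Length of one flattened patch: P * P * C pixel values."""
    return patch_size * patch_size * channels


# A 224x224 RGB image cut into 16x16 patches is "worth" 14 x 14 = 196 words.
print(vit_sequence_length(224, 16, with_class_token=False))  # 196
print(vit_sequence_length(224, 16))  # 197 (196 patches + the class token)
print(patch_dim(16))  # 768 values per flattened RGB patch
```

Smaller patches mean more "words" per image: with 14×14 patches the same image yields 256 patches instead of 196.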

How does it work?
The first step is to split the image into patches and turn them into patch embeddings, which are then fed to the Transformer encoder:
1. Split the image into fixed-size patches (e.g. 16×16 pixels) and flatten each one.
2. Pass the flattened patches through a linear projection layer to get the initial patch embeddings.
3. Prepend a trainable “class” embedding to the patch embeddings; its output at the last encoder layer serves as the image representation for classification.
4. Add learned positional embeddings to the resulting sequence.
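The four steps above can be sketched in NumPy as follows. This is a minimal, forward-pass-only sketch: the helper names `patchify` and `embed` are hypothetical, and random arrays stand in for the learned parameters `E`, `cls_token` and `pos_embed`, which a real model would train inside a deep learning framework.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Step 1: split an (H, W, C) image into N flattened patches, shape (N, P*P*C)."""
    h, w, c = image.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("image height and width must be divisible by the patch size")
    # (H, W, C) -> (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C): patches in row-major grid order
    grid = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(-1, p * p * c)


def embed(image, E, cls_token, pos_embed, patch_size):
    """Steps 1-4: patchify, project with E, prepend the class token, add positions."""
    patches = patchify(image, patch_size)         # (N, P*P*C)
    tokens = patches @ E                          # (N, D) patch embeddings
    tokens = np.concatenate([cls_token, tokens])  # (N + 1, D), class token first
    return tokens + pos_embed                     # (N + 1, D)


rng = np.random.default_rng(0)
P, D = 16, 64                      # D kept small for illustration (ViT-Base uses 768)
image = rng.random((224, 224, 3))  # stand-in for a preprocessed RGB image
N = (224 // P) ** 2                # 196 patches

# Random stand-ins for parameters a real model learns during training
E = rng.normal(scale=0.02, size=(P * P * 3, D))
cls_token = rng.normal(scale=0.02, size=(1, D))
pos_embed = rng.normal(scale=0.02, size=(N + 1, D))

z0 = embed(image, E, cls_token, pos_embed, P)
print(z0.shape)  # (197, 64)
```

The reshape/transpose trick avoids Python loops: each row of the result is one patch, read left to right and top to bottom across the image.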

A linear projection layer maps each flattened image patch "array" to a patch embedding "vector". It is a single feed-forward (fully connected) layer whose weight is the embedding matrix E, which is learned during training. Image patches are thus treated the same way as tokens (words) in an NLP application. The model is trained on image classification in a supervised fashion.
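In the paper's notation, an H×W image with C channels is reshaped into N = HW/P² flattened patches x_p^i of size P²·C, D is the hidden size, x_class is the learnable class embedding, and L is the number of encoder layers. The full forward pass then reads:

```latex
z_0 = \left[\, x_{\text{class}};\; x_p^1 E;\; x_p^2 E;\; \cdots;\; x_p^N E \,\right] + E_{\text{pos}},
\qquad E \in \mathbb{R}^{(P^2 \cdot C) \times D},\quad E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}

z'_\ell = \mathrm{MSA}\!\left(\mathrm{LN}(z_{\ell-1})\right) + z_{\ell-1}, \qquad \ell = 1, \dots, L

z_\ell = \mathrm{MLP}\!\left(\mathrm{LN}(z'_\ell)\right) + z'_\ell, \qquad \ell = 1, \dots, L

y = \mathrm{LN}\!\left(z_L^0\right)
```

The first line is the embedding step (linear projection E, class token, position embeddings); y, the normalized final state of the class token (index 0), is what the classification head sees.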

[Figure: ViT model overview. Source: https://arxiv.org/pdf/2010.11929.pdf]

Position embeddings are added to the patch embeddings to retain positional information. The authors use standard learnable 1D position embeddings, since they did not observe significant performance gains from more advanced 2D-aware position embeddings (whereas there is a large gap between a model with no positional embedding and models with positional embeddings). Like the original Transformer, the encoder consists of alternating layers of multi-headed self-attention (MSA) and MLP blocks. Unlike the original, LayerNorm (LN) is applied before every block, and the residual connection is added after every block (a "pre-norm" layout). The MLP contains two layers with a Gaussian Error Linear Unit (GELU) non-linearity.

PyTorch GELU reference: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
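A single encoder block can be sketched in NumPy as below. It is a minimal forward-pass sketch under simplifying assumptions: no biases, no learnable LayerNorm scale/shift, no dropout, random weights and toy sizes; parameter names such as `w_qkv` are hypothetical. The residual structure follows the pre-norm formulation z'_ℓ = MSA(LN(z_{ℓ−1})) + z_{ℓ−1} and z_ℓ = MLP(LN(z'_ℓ)) + z'_ℓ.

```python
import math
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize each token over its features (learnable scale and shift omitted)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


_erf = np.vectorize(math.erf, otypes=[float])


def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def msa(x, w_qkv, w_out, num_heads):
    """Multi-head self-attention over a sequence x of shape (n, d)."""
    n, d = x.shape
    head_dim = d // num_heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)  # each (n, d)
    # Split the feature dimension into heads: (n, d) -> (heads, n, head_dim)
    q, k, v = [t.reshape(n, num_heads, head_dim).transpose(1, 0, 2) for t in (q, k, v)]
    weights = softmax(q @ k.transpose(0, 2, 1) / math.sqrt(head_dim))  # (heads, n, n)
    out = (weights @ v).transpose(1, 0, 2).reshape(n, d)  # concatenate the heads
    return out @ w_out


def encoder_block(z, p, num_heads):
    # z'_l = MSA(LN(z_{l-1})) + z_{l-1}
    z = z + msa(layer_norm(z), p["w_qkv"], p["w_out"], num_heads)
    # z_l = MLP(LN(z'_l)) + z'_l, with GELU between the two MLP layers
    return z + gelu(layer_norm(z) @ p["w1"]) @ p["w2"]


rng = np.random.default_rng(0)
n_tokens, d, mlp_dim, heads = 197, 64, 256, 4  # toy sizes; ViT-Base uses d=768, MLP 3072, 12 heads
params = {
    "w_qkv": rng.normal(scale=0.02, size=(d, 3 * d)),
    "w_out": rng.normal(scale=0.02, size=(d, d)),
    "w1": rng.normal(scale=0.02, size=(d, mlp_dim)),
    "w2": rng.normal(scale=0.02, size=(mlp_dim, d)),
}
z = rng.normal(size=(n_tokens, d))
out = encoder_block(z, params, heads)
print(out.shape)  # (197, 64)
```

Because LN sits inside each residual branch, the skip path carries the unnormalized signal straight through the stack, which is the main difference from the post-norm original Transformer.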

What are the variants?
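The “Base” and “Large” models are adopted directly from BERT, and the paper adds a larger “Huge” model. A brief notation indicates the model size and the input patch size: for instance, ViT-L/16 means the “Large” variant with a 16×16 input patch size (hence the name of the paper 😜). Note that the sequence length is inversely proportional to the square of the patch size, so models with smaller patches are computationally more expensive.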
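The configurations below are the ones listed in the paper's Table 1 (layers, hidden size D, MLP size, heads). The `parse_name` and `approx_params` helpers are hypothetical back-of-the-envelope tools, assuming standard dense layers with biases, a 224×224 RGB input and no classification head; their estimates land close to, but not exactly on, the parameter counts the paper reports (86M, 307M and 632M).

```python
# ViT sizes as listed in the paper (Table 1)
VIT_CONFIGS = {
    "B": {"layers": 12, "hidden": 768, "mlp": 3072, "heads": 12},
    "L": {"layers": 24, "hidden": 1024, "mlp": 4096, "heads": 16},
    "H": {"layers": 32, "hidden": 1280, "mlp": 5120, "heads": 16},
}


def parse_name(name):
    """Parse notation like 'ViT-L/16' into (size letter, patch size)."""
    size, patch = name.removeprefix("ViT-").split("/")
    return size, int(patch)


def approx_params(name, image_size=224, channels=3):
    """Rough parameter count of embeddings + encoder (classification head excluded)."""
    size, patch = parse_name(name)
    cfg = VIT_CONFIGS[size]
    d, m, n_layers = cfg["hidden"], cfg["mlp"], cfg["layers"]
    attn = 4 * d * d + 4 * d   # Q, K, V and output projections, with biases
    mlp = 2 * d * m + m + d    # two dense layers, with biases
    norms = 2 * 2 * d          # two LayerNorms (scale + shift) per block
    num_patches = (image_size // patch) ** 2
    embed = patch * patch * channels * d + d              # projection E with bias
    extras = d + (num_patches + 1) * d + 2 * d            # class token, position embeddings, final LN
    return n_layers * (attn + mlp + norms) + embed + extras


for name in ["ViT-B/16", "ViT-L/16", "ViT-H/14"]:
    print(name, f"~{approx_params(name) / 1e6:.0f}M parameters")
```

Almost all parameters sit in the encoder layers (roughly 12·D² per layer), so the patch size barely changes the parameter count; what it changes is the sequence length, and hence the compute.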

What are the training & fine-tuning configs?

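The authors train all models, including the ResNet baselines, using Adam with β1 = 0.9, β2 = 0.999, a batch size of 4096 and a high weight decay of 0.1, together with a linear learning rate warmup and decay. Training is done at resolution 224. For fine-tuning they use SGD with momentum and a batch size of 512 for all models, and run a small grid search over learning rates (the ranges are listed in the paper). When fine-tuning at a higher resolution, the patch size is kept the same, which gives a longer sequence. The Vision Transformer can handle arbitrary sequence lengths (up to memory constraints), but the pre-trained position embeddings no longer match the new patch grid, so the authors 2D-interpolate them according to their location in the original image.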
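A linear warmup-and-decay schedule can be sketched as a plain function of the step. This is a minimal sketch: the name `linear_warmup_decay` and the example values are hypothetical and illustrative only, not the paper's exact settings.

```python
def linear_warmup_decay(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear ramp from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    decay_fraction = (total_steps - step) / (total_steps - warmup_steps)
    return base_lr * max(0.0, decay_fraction)


# Illustrative values only: peak LR 1e-3, 10k warmup steps, 100k steps in total
for step in [0, 5_000, 10_000, 55_000, 100_000]:
    print(step, linear_warmup_decay(step, 1e-3, 10_000, 100_000))
```

In the Hugging Face stack the same shape is produced by `transformers.get_linear_schedule_with_warmup`, and `TrainingArguments` exposes it through `lr_scheduler_type="linear"` together with `warmup_steps`.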

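But how does it perform?
When trained on mid-sized datasets such as ImageNet without strong regularization, these models yield modest accuracies, a few percentage points below ResNets of comparable size. The paper attributes this to Transformers lacking some of the inductive biases built into CNNs, such as translation equivariance and locality, so they do not generalize as well when trained on insufficient data. However, the picture changes when the models are trained on larger datasets (14M to 300M images) such as ImageNet-21k and JFT-300M: large-scale training trumps inductive bias.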

[Figure: comparison with the state of the art on popular image classification benchmarks. Source: https://arxiv.org/pdf/2010.11929.pdf]

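We see that ViT-L/16 and ViT-H/14, pre-trained on JFT-300M, beat the ResNet-based BiT-L and the EfficientNet-based Noisy Student on almost all benchmarks while needing far fewer TPUv3-core-days to train. TPUv3-core-days is the number of TPUv3 cores used for training multiplied by the training time in days.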

Additional things to know?

Left: Filters of the initial linear embedding of RGB values of ViT-L/32. (These look quite similar to the low-level filters learned by the first layer of a CNN.)
Center: Similarity of learnt position embeddings of ViT-L/32. Tiles show the cosine similarity between the position embedding of the patch with the indicated row and column and the position embeddings of all other patches. This shows that even with 1D embeddings the model learns the 2D layout: nearby patches have more similar position embeddings, and a row-column structure appears.
Right: Size of attended area by head and network depth. Each dot shows the mean attention distance across images for one of 16 heads at one layer. (In CNNs, shallow layers only see small parts of the image and the receptive field grows with depth. In ViT, some heads already attend to most of the image in the lowest layers while others stay local, and the attention distance grows with depth.)
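The quantities behind the center and right panels can be sketched with two small helpers. Both are hypothetical illustrations run on toy inputs, not on a trained model. For the mean attention distance we assume the Euclidean pixel distance between patch positions, weighted by the attention weights and ignoring the class token, which follows the paper's description of the metric but may differ from its exact code.

```python
import numpy as np


def position_embedding_similarity(pos_embed, grid_size):
    """Cosine similarity between every pair of patch position embeddings.

    pos_embed: (grid_size**2, D) array, class-token embedding excluded.
    Returns shape (grid, grid, grid, grid); sim[r, c] is the tile for the patch
    in row r, column c (the center panel plots one such tile per patch).
    """
    unit = pos_embed / np.linalg.norm(pos_embed, axis=-1, keepdims=True)
    sim = unit @ unit.T  # (N, N) cosine similarities
    return sim.reshape(grid_size, grid_size, grid_size, grid_size)


def mean_attention_distance(attn, grid_size, patch_size):
    """Attention-weighted pixel distance between each query patch and the patches it attends to.

    attn: (heads, N, N) attention weights over patch tokens, rows summing to 1.
    Returns one value per head, averaged over all query patches.
    """
    rows, cols = np.divmod(np.arange(grid_size * grid_size), grid_size)
    coords = np.stack([rows, cols], axis=-1) * patch_size  # patch offsets in pixels
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N)
    return (attn * dist).sum(axis=-1).mean(axis=-1)


rng = np.random.default_rng(0)
grid, patch = 14, 16
sim = position_embedding_similarity(rng.normal(size=(grid * grid, 32)), grid)
print(sim.shape)  # (14, 14, 14, 14)

n = grid * grid
local_head = np.eye(n)[None]               # attends only to its own patch
global_head = np.full((1, n, n), 1.0 / n)  # spreads attention uniformly over the image
print(mean_attention_distance(local_head, grid, patch))  # [0.]
print(mean_attention_distance(global_head, grid, patch))
```

A head that only looks at its own patch has distance 0, while a head that spreads attention uniformly gets a distance on the order of half the image width; real heads in the right panel fall between these two extremes.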

Let's dive into some code 🧑‍💻:

# Install the Hugging Face libraries (notebook syntax; drop the "!" in a terminal).
# Trainer also needs accelerate, and report_to='tensorboard' needs tensorboard.
!pip install datasets transformers accelerate tensorboard

import torch
import numpy as np
from datasets import load_dataset
# ViTImageProcessor applies the resizing/normalization ViT expects before images reach the model
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
)

# We'll use the beans dataset, which is a collection of pictures of healthy
# and unhealthy bean leaves. 🍃
ds = load_dataset('beans')

# Let's look at what a sample contains
ds['train'][400]
# {
#   'image': <PIL.JpegImagePlugin ...>,
#   'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
#   'labels': 1
# }

# Let's look at the labels
labels = ds['train'].features['labels']
labels
# ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust',
#            'healthy'], names_file=None, id=None)

[Figure: A grid of a few examples from each class in the dataset]

# Helper functions

# Transform applied on the fly to each batch of examples
def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')
    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

# collate_fn stacks individual examples into a batch dict of tensors
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# Accuracy computed directly with NumPy (datasets.load_metric has been
# removed from recent releases of the datasets library)
def compute_metrics(p):
    predictions = np.argmax(p.predictions, axis=1)
    return {"accuracy": float((predictions == p.label_ids).mean())}

# Load the image processor for the pre-trained checkpoint and attach the transform
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
prepared_ds = ds.with_transform(transform)
labels = ds['train'].features['labels'].names

# Load the model; a new classification head with len(labels) outputs is added on top
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Now we are ready to start training!

# Training arguments for the Trainer
training_args = TrainingArguments(
    output_dir="./vit-base-beans",
    per_device_train_batch_size=16,
    eval_strategy="steps",  # older transformers releases call this evaluation_strategy
    num_train_epochs=4,
    fp16=True,  # mixed precision; requires a CUDA GPU
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep the 'image' column so transform() can use it
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

# Trainer configuration
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    processing_class=processor,  # older transformers releases take tokenizer=processor
)

# Train the ViT model, then log and save metrics for train and validation
train_results = trainer.train()
trainer.save_model()
# train
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
# validation
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

Appendix:
1. Hugging Face code reference: https://huggingface.co/blog/fine-tune-vit
2. The original paper: https://arxiv.org/pdf/2010.11929.pdf
3. Video walkthrough by Aleksa Gordić (The AI Epiphany): https://www.youtube.com/watch?v=j6kuz_NqkG0

And this is how Transformers are overthrowing even CNNs, after having done the same to LSTMs. ViT was also used by OpenAI as the image encoder in its CLIP models, an open-source, multi-modal (text and images) model family with state-of-the-art zero-shot performance.

In case you found this information helpful, do clap for the blog 🙈
