

To find the $k$ nearest neighbors of a point, you need a distance function. The default choice is Euclidean distance:

$$d(x, y) = \sqrt{\sum_{i=1}^d (x_i - y_i)^2}$$

This distance treats every dimension as equally important. If your features are pixel values in a $100 \times 100$ image, Euclidean distance sums squared differences across all 10,000 pixels. The result is problematic in a precise way. The Euclidean distance between two photos of the same face taken under different lighting can be enormous, because nearly every pixel value changes. The Euclidean distance between photos of two different people taken under similar conditions can be small, because their average pixel values and lighting are similar. The distance function disagrees with the human notion of “same person.”
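A small synthetic illustration, assuming pixel intensities stored as floats. The random arrays below stand in for photos and are not real face images. A uniform brightness shift moves each of the 10,000 pixels by only 0.3, yet Euclidean distance adds all of those small changes up. Replacing a $10 \times 10$ patch with unrelated content changes far fewer pixels and gives a smaller distance.

```python
import numpy as np

def euclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Euclidean distance between two images, treating every pixel as a dimension."""
    return float(np.sqrt(np.sum((x - y) ** 2)))

rng = np.random.default_rng(0)
# Synthetic 100x100 grayscale "photo" with intensities in [0, 0.7] (illustrative only).
photo = rng.uniform(0.0, 0.7, size=(100, 100))

# Same content under brighter lighting: all 10,000 pixels shift by 0.3.
brighter = photo + 0.3

# Same lighting, but a 10x10 patch (100 pixels) replaced with unrelated content.
patched = photo.copy()
patched[:10, :10] = rng.uniform(0.0, 1.0, size=(10, 10))

print(euclidean(photo, brighter))  # sqrt(10,000 * 0.3^2) = 30, up to rounding
print(euclidean(photo, patched))   # at most sqrt(100 * 1^2) = 10
```

The lighting change, which a human would ignore, produces the larger distance.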

The core question: instead of assuming a distance function, can we learn one from data?

What Makes a Valid Distance Metric

A valid distance metric is a function $d: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ satisfying three properties:

  1. Non-negativity and identity: $d(x, y) \geq 0$, with $d(x, y) = 0$ if and only if $x = y$.
  2. Symmetry: $d(x, y) = d(y, x)$.
  3. Triangle inequality: $d(x, z) \leq d(x, y) + d(y, z)$.

The triangle inequality is the most substantive: it says you cannot have a shorter path from $x$ to $z$ by going via $y$ than by going directly.

The Mahalanobis distance is the most natural generalization of Euclidean distance that can be learned:

$$d_M(x, y) = \sqrt{(x - y)^\top M (x - y)}$$

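Here $M$ is a $d \times d$ positive semi-definite matrix. When $M = I$ (the identity matrix), this reduces exactly to Euclidean distance. When $M$ is diagonal, it scales each dimension independently, giving a weighted Euclidean distance. When $M$ is a full positive semi-definite matrix, it can rotate and scale the feature space in any direction. (The classical Mahalanobis distance fixes $M = \Sigma^{-1}$, the inverse covariance matrix of the data; metric learning instead treats $M$ as a parameter to fit.) Positive semi-definiteness guarantees non-negativity, symmetry, and the triangle inequality. It does not fully guarantee identity: if $M$ is singular, two distinct points whose difference lies in the null space of $M$ are at distance zero, so $d_M$ is only a pseudometric. Requiring $M$ to be positive definite makes it a true metric.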

Intuition for what $M$ does. Any positive semi-definite matrix $M$ can be written as $M = L^\top L$ for some matrix $L$. Then $(x-y)^\top M (x-y) = (x-y)^\top L^\top L (x-y) = \|L(x-y)\|_2^2$, so $d_M(x, y) = \|L(x-y)\|_2$. The Mahalanobis distance is just the Euclidean distance after applying the linear transformation $L$ to the data. Learning $M$ is equivalent to learning a linear transformation of the feature space in which Euclidean distance is meaningful.
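A quick numerical check of $d_M(x, y) = \|L(x-y)\|_2$, sketched with NumPy on random data (the helper names are illustrative). `factor_psd` recovers one valid $L$ from a given $M$ through an eigendecomposition, $M = V \operatorname{diag}(w) V^\top$ and $L = \operatorname{diag}(\sqrt{w})\, V^\top$. The factor is not unique: multiplying $L$ on the left by any orthogonal matrix gives another valid factor.

```python
import numpy as np

def mahalanobis(x, y, M):
    """d_M(x, y) = sqrt((x - y)^T M (x - y)) for a PSD matrix M."""
    diff = x - y
    # Clamp at zero to guard against tiny negative values from rounding.
    return float(np.sqrt(max(diff @ M @ diff, 0.0)))

def transformed_euclidean(x, y, L):
    """Plain Euclidean distance after mapping both points through L."""
    return float(np.linalg.norm(L @ x - L @ y))

def factor_psd(M):
    """Return some L with L.T @ L == M (up to rounding) for a PSD matrix M."""
    w, V = np.linalg.eigh(M)
    w = np.clip(w, 0.0, None)          # drop tiny negative eigenvalues from rounding
    return np.sqrt(w)[:, None] * V.T   # diag(sqrt(w)) @ V.T

rng = np.random.default_rng(0)
d = 5
L = rng.normal(size=(3, d))   # an arbitrary 3x5 linear map (illustrative)
M = L.T @ L                   # PSD by construction (rank 3, so singular)
x, y = rng.normal(size=d), rng.normal(size=d)

print(mahalanobis(x, y, M), transformed_euclidean(x, y, L))  # equal up to rounding
```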

Learning the matrix $M$ is metric learning.

Linear Metric Learning - LMNN

Large Margin Nearest Neighbor (LMNN) is a classical metric learning algorithm. The idea is intuitive: for each training point $x_i$, designate its $k$ nearest same-class neighbors (chosen in advance, typically with Euclidean distance) as its target neighbors, the points it should be close to. Then optimize $M$ toward two goals:

  1. Target neighbors should be pulled close to $x_i$ (small $d_M(x_i, x_j)$ for each target neighbor $x_j$).
  2. Points of a different class should be farther from $x_i$ than each of its target neighbors by at least a margin $\delta$. A differently labeled point that violates this margin is called an impostor, and the objective pushes impostors away.

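Formally, the objective can be written as a semidefinite program (SDP): an optimization problem with a linear objective, linear constraints, and the constraint that $M$ must be positive semi-definite. The problem is convex, so any local optimum is global, but general-purpose SDP solvers scale poorly. The number of impostor constraints can grow on the order of $k n^2$ for $n$ training points, so practical LMNN solvers use specialized (sub)gradient methods that track only the currently active constraints instead of handing the full problem to a generic SDP solver.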
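The two goals can be sketched as a loss over the linear map $L$, with $M = L^\top L$. This sketch only evaluates the objective on toy data; it does not solve the SDP. The weighting `mu` between the pull and push terms is an assumed convention, and `delta` is the margin $\delta$. Optimizing over $L$ instead of $M$ removes the positive semi-definite constraint but makes the problem non-convex.

```python
import numpy as np

def lmnn_loss(X, y, L, k=1, mu=0.5, delta=1.0):
    """Evaluate an LMNN-style pull/push objective for M = L^T L.

    X: (n, d) data, y: (n,) labels, L: (r, d) linear map.
    k: target neighbors per point, chosen once with plain Euclidean distance.
    mu: weight between pull and push terms (assumed convention).
    delta: required margin between a target neighbor and any impostor.
    """
    n = X.shape[0]
    # Squared Euclidean distances in the original space, used only to pick targets.
    sq_orig = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    # Squared learned distances d_M^2 = Euclidean distances after applying L.
    Z = X @ L.T
    sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)

    pull, push = 0.0, 0.0
    for i in range(n):
        same = np.flatnonzero((y == y[i]) & (np.arange(n) != i))
        other = np.flatnonzero(y != y[i])
        targets = same[np.argsort(sq_orig[i, same])[:k]]
        for j in targets:
            pull += sq[i, j]  # pull target neighbor j toward x_i
            # Hinge penalty for every differently labeled point inside the margin.
            push += np.sum(np.maximum(0.0, delta + sq[i, j] - sq[i, other]))
    return float((1 - mu) * pull + mu * push)

# Toy usage: two classes separated along the first coordinate.
X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 0.0], [5.0, 1.0]])
y = np.array([0, 0, 1, 1])
print(lmnn_loss(X, y, np.eye(2)))            # 2.0: no impostors, only the pull term
print(lmnn_loss(X, y, np.diag([0.0, 1.0])))  # 8.0: discarding the separating axis creates impostors
```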

LMNN works well when the data is low-dimensional and the class structure is locally consistent. It is an important baseline and its formulation makes the intuition of metric learning concrete.

Siamese Networks

The deep learning approach to metric learning replaces the linear transformation $L$ with a neural network $f_\theta$. Two copies of the same network - sharing all weights - process two inputs and produce embeddings. The distance between the embeddings determines whether the inputs are similar.

Architecture. A Siamese network takes two inputs $x_a$ and $x_b$, passes each through the same neural network $f_\theta$ (same weights, applied independently), and produces embeddings $z_a = f_\theta(x_a)$ and $z_b = f_\theta(x_b)$. The distance $d(z_a, z_b) = \|z_a - z_b\|_2$ serves as a dissimilarity score: the smaller the distance, the more similar the inputs.

Weight sharing is critical: because both inputs pass through the same function, swapping them cannot change the result, so the distance is symmetric. It also means both branches map into one common embedding space, so any input can be embedded once and compared against any other embedding.
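A minimal forward-pass sketch of weight sharing in NumPy. A two-layer MLP stands in for $f_\theta$; the architecture, sizes, and helper names are illustrative assumptions, and training would additionally need an autodiff library. Both branches call the same `embed` with the same `params`, which is all weight sharing amounts to in code.

```python
import numpy as np

def init_params(rng, in_dim, hidden, out_dim):
    """A single set of weights, used by both branches of the Siamese network."""
    return {
        "W1": rng.normal(scale=in_dim ** -0.5, size=(in_dim, hidden)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(scale=hidden ** -0.5, size=(hidden, out_dim)),
        "b2": np.zeros(out_dim),
    }

def embed(params, x):
    """f_theta: a small two-layer ReLU MLP (stand-in for a CNN)."""
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def siamese_distance(params, x_a, x_b):
    """Embed both inputs with the SAME params and return ||z_a - z_b||_2."""
    z_a, z_b = embed(params, x_a), embed(params, x_b)
    return np.linalg.norm(z_a - z_b, axis=-1)

rng = np.random.default_rng(0)
params = init_params(rng, in_dim=8, hidden=16, out_dim=4)
x_a, x_b = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))  # a batch of 3 pairs
print(siamese_distance(params, x_a, x_b))  # one distance per pair
```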

Training with contrastive loss. Training pairs are labeled “same class” or “different class.” The contrastive loss is:

$$\mathcal{L} = y \cdot d(z_a, z_b)^2 + (1 - y) \cdot \max(0, m - d(z_a, z_b))^2$$

Here $y = 1$ if the pair is same-class and $y = 0$ if different-class, and $m$ is a margin. For same-class pairs ($y=1$), the loss pulls the embeddings together by minimizing their squared distance. For different-class pairs ($y=0$), the loss pushes the embeddings apart, but only up to the margin $m$: once they are farther apart than $m$, the loss is zero and there is no need to push them further.
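The contrastive loss translates directly into a few lines of NumPy. This sketch averages over a batch of embedding pairs; averaging is an assumed convention, and some formulations also include a factor of $1/2$.

```python
import numpy as np

def contrastive_loss(z_a, z_b, y, m=1.0):
    """Contrastive loss averaged over a batch of embedding pairs.

    y = 1 for same-class pairs, 0 for different-class pairs; m is the margin.
    """
    d = np.linalg.norm(z_a - z_b, axis=-1)
    same = y * d ** 2                                  # pull same-class pairs together
    different = (1 - y) * np.maximum(0.0, m - d) ** 2  # push apart, only up to m
    return float(np.mean(same + different))

z_a = np.zeros((3, 2))
z_b = np.array([[0.3, 0.4], [0.3, 0.4], [3.0, 4.0]])  # distances 0.5, 0.5, 5.0
y = np.array([1, 0, 0])  # same-class, different-class, different-class
print(contrastive_loss(z_a, z_b, y, m=1.0))  # (0.25 + 0.25 + 0.0) / 3, about 0.167
```

The third pair is already farther apart than the margin, so it contributes nothing.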

The network learns a feature representation where semantically similar inputs are nearby and dissimilar inputs are far away. The final layer of the network is the embedding, not a class label. This distinction matters: the output is a location in a learned metric space, not a softmax over fixed classes.

Triplet Loss

Contrastive loss works but has a practical weakness: the quality of training depends heavily on which pairs you choose. Randomly sampled different-class pairs are often too easy: very different examples are already farther apart than the margin $m$, so they contribute zero loss and zero gradient.

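Triplet loss trains on triplets instead of pairs. Rather than requiring same-class distances to be near zero and different-class distances to exceed a fixed margin, it only requires that, relative to a given anchor, the same-class example be closer than the different-class one. A triplet $(a, p, n)$ consists of: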

  • An anchor $a$ - the reference point.
  • A positive $p$ - a different example from the same class as $a$.
  • A negative $n$ - an example from a different class.

The loss enforces that the anchor is closer to the positive than to the negative by at least a margin $\alpha$:

$$\mathcal{L} = \max\left(0,\ d(f(a), f(p))^2 - d(f(a), f(n))^2 + \alpha\right)$$

If the anchor is already closer to the positive than to the negative by more than $\alpha$ in squared distance, the loss is zero: the triplet is “easy” and contributes nothing. Gradient flows only for triplets with $d(f(a), f(n))^2 < d(f(a), f(p))^2 + \alpha$, that is, where the constraint is violated or satisfied by less than the margin.

Hard negative mining. For a given anchor and positive, choose the negative that is closest to the anchor in the current embedding space: the “hardest” negative. This ensures that most triplets in a batch are hard and contribute meaningful gradient; without some form of mining, most triplets are easy and training stagnates. In practice, semi-hard negatives (negatives farther from the anchor than the positive, but still within the margin) are often more stable than the hardest negatives, which can make training collapse early, for example into mapping every input to the same embedding.
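A sketch of online mining within a single batch of embeddings, with hypothetical helper names. For every (anchor, positive) pair it picks either the hardest negative or, by default, the closest semi-hard negative. Falling back to the hardest negative when no semi-hard one exists is one possible policy, assumed here rather than taken from a specific paper.

```python
import numpy as np

def pairwise_sq_dists(Z):
    """Squared Euclidean distances between all rows of Z."""
    sq = np.sum(Z ** 2, axis=1)
    return np.maximum(sq[:, None] - 2 * Z @ Z.T + sq[None, :], 0.0)

def mine_negatives(Z, labels, alpha=0.2, semi_hard=True):
    """Return (anchor, positive, negative) index triples for one batch.

    hardest:   the negative closest to the anchor.
    semi-hard: the closest negative with d(a,p)^2 < d(a,n)^2 < d(a,p)^2 + alpha,
               falling back to the hardest negative if none exists (assumed policy).
    """
    D = pairwise_sq_dists(Z)
    triplets = []
    for a in range(len(Z)):
        negatives = np.flatnonzero(labels != labels[a])
        if negatives.size == 0:
            continue
        for p in np.flatnonzero(labels == labels[a]):
            if p == a:
                continue
            d_ap, d_an = D[a, p], D[a, negatives]
            choice = negatives[np.argmin(d_an)]  # hardest negative
            if semi_hard:
                mask = (d_an > d_ap) & (d_an < d_ap + alpha)
                if mask.any():
                    choice = negatives[mask][np.argmin(d_an[mask])]
            triplets.append((a, int(p), int(choice)))
    return triplets

# Anchor 0 and positive 1 are 0.5 apart; negative 2 is hardest, 3 is semi-hard, 4 is easy.
Z = np.array([[0.0, 0.0], [0.5, 0.0], [0.1, 0.0], [0.6, 0.0], [2.0, 0.0]])
labels = np.array([0, 0, 1, 1, 1])
print(mine_negatives(Z, labels, semi_hard=True)[:2])
print(mine_negatives(Z, labels, semi_hard=False)[:2])
```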

Worked example. Anchor $a$ is a photo of person A. Positive $p$ is a different photo of person A. Negative $n$ is a photo of person B. The triplet loss says the embedding of $a$ must be closer to $p$ than to $n$ by at least $\alpha = 0.2$ in squared distance. If $d(f(a), f(p)) = 0.5$ and $d(f(a), f(n)) = 0.6$, the loss is $\max(0, 0.5^2 - 0.6^2 + 0.2) = \max(0, 0.25 - 0.36 + 0.2) = \max(0, 0.09) = 0.09$. The margin constraint is violated and the loss is positive. If instead $d(f(a), f(n)) = 0.8$, the loss is $\max(0, 0.25 - 0.64 + 0.2) = \max(0, -0.19) = 0$. The triplet is easy and contributes no gradient.
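A short NumPy sketch of the triplet loss that evaluates the two triplets from the worked example. It uses squared Euclidean distances exactly as in the formula, so the violated triplet gives $0.5^2 - 0.6^2 + 0.2 = 0.09$ and the easy one gives 0. The one-dimensional embeddings are chosen only to produce those distances.

```python
import numpy as np

def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    """Triplet loss with squared Euclidean distances, averaged over a batch."""
    d_ap = np.sum((f_a - f_p) ** 2, axis=-1)
    d_an = np.sum((f_a - f_n) ** 2, axis=-1)
    return float(np.mean(np.maximum(0.0, d_ap - d_an + alpha)))

a = np.array([[0.0]])
p = np.array([[0.5]])         # d(f(a), f(p)) = 0.5
n_close = np.array([[0.6]])   # d(f(a), f(n)) = 0.6
n_far = np.array([[0.8]])     # d(f(a), f(n)) = 0.8

print(triplet_loss(a, p, n_close))  # about 0.09: margin violated
print(triplet_loss(a, p, n_far))    # 0.0: easy triplet
```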

Applications

Face verification. Face verification asks: are these two photos of the same person? Not “who is this person?” (face recognition) but “are these the same?” FaceNet (Schroff et al., Google, 2015) trains a deep CNN with triplet loss to produce 128-dimensional embeddings of face images. Its training set contained roughly 100 to 200 million face images of about 8 million identities. At test time, two faces are embedded and the L2 distance between the embeddings is computed. If the distance is below a threshold chosen on held-out pairs, the system says “same person.” The system is never asked to recognize a specific identity, only to determine sameness, which lets it verify identities it never saw during training.
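A sketch of the verification decision, plus one simple way to choose the threshold: scan candidate thresholds on labeled validation pairs and keep the one with the highest accuracy. The helper names and the midpoint candidate scheme are illustrative assumptions, not FaceNet's published procedure.

```python
import numpy as np

def verify(z_a, z_b, threshold):
    """Predict "same person" iff the L2 distance between embeddings is below threshold."""
    return np.linalg.norm(z_a - z_b, axis=-1) < threshold

def pick_threshold(distances, same):
    """Threshold maximizing accuracy on labeled validation pairs.

    distances: L2 distances of validation pairs; same: boolean labels.
    Candidates are midpoints between sorted distances, plus one just below
    the smallest and one just above the largest (a simple, assumed scheme).
    """
    d = np.sort(distances)
    candidates = np.concatenate(([d[0] - 1e-6], (d[:-1] + d[1:]) / 2, [d[-1] + 1e-6]))
    accuracies = [np.mean((distances < t) == same) for t in candidates]
    return float(candidates[int(np.argmax(accuracies))])

val_dist = np.array([0.3, 0.5, 0.9, 1.2])
val_same = np.array([True, True, False, False])
t = pick_threshold(val_dist, val_same)                 # midpoint of 0.5 and 0.9
print(t, verify(np.zeros(2), np.array([0.3, 0.4]), t))  # distance 0.5 is below t
```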

Few-shot learning. Suppose you have 5 labeled examples of a new animal species and want to classify unseen images. A standard classifier cannot handle new classes without retraining. But in a good metric space, you can classify by nearest-neighbor: embed the 5 examples and the query image, and assign the query to the class whose examples are closest. Prototypical Networks extend this: compute the mean embedding (“prototype”) of each class’s few examples, and classify by nearest prototype. This works because the metric space generalizes - the embedding function learned over many training classes produces an embedding where new classes also cluster.
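A minimal sketch of nearest-prototype classification, assuming the support and query embeddings were already produced by a trained $f_\theta$. The 2-D toy embeddings and the class names below are made up for illustration. Squared Euclidean distance to each prototype is used, as in Prototypical Networks.

```python
import numpy as np

def prototypes(support_z, support_labels):
    """Mean embedding of each class's few support examples."""
    classes = np.unique(support_labels)
    protos = np.stack([support_z[support_labels == c].mean(axis=0) for c in classes])
    return classes, protos

def classify(query_z, classes, protos):
    """Assign each query to the class of the nearest prototype (squared L2)."""
    d = np.sum((query_z[:, None, :] - protos[None, :, :]) ** 2, axis=-1)
    return classes[np.argmin(d, axis=1)]

support_z = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
support_labels = np.array(["cat", "cat", "fox", "fox"])
classes, protos = prototypes(support_z, support_labels)  # prototypes (0, 1) and (10, 1)
pred = classify(np.array([[1.0, 1.0], [9.0, 0.0]]), classes, protos)
print(pred)  # ['cat' 'fox']
```

Adding a new class only requires embedding its few examples and averaging them; $f_\theta$ itself is unchanged.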

Person re-identification. A surveillance system with cameras at different locations needs to determine whether the same person appears in multiple cameras. Raw pixel distance fails badly (different lighting, viewing angle, occlusion). A metric-learned embedding, trained on pairs of images of the same and different identities, produces representations that are far more robust to these variations. Because the model compares appearances rather than classifying a fixed list of people, it can match identities it never saw in training, and it can be deployed at a new location without retraining to the extent that the new cameras resemble those in the training data.

Signature verification. Determine whether two handwritten signatures were signed by the same person. The training set contains genuine-genuine and genuine-forgery pairs. A Siamese network learns an embedding in which genuine signatures of the same person are close and forgeries are far from the genuine signatures they imitate.

Why This Differs From Training a Classifier

A softmax classifier learns a fixed set of $C$ output classes. The final layer maps the embedding to $C$ scores. Everything the classifier knows is about those $C$ classes. At test time, it can only assign inputs to one of those classes.

A metric learner learns a similarity function. Its output is an embedding, not a class label. Because the network has learned what similarity means in general - not which specific classes to recognize - it generalizes to new classes at test time. Give it 5 examples of a class it never saw and it can tell you which new examples belong to that class by distance. This is the fundamental reason metric learning is useful for few-shot and open-set recognition problems: the learned distance function transfers; the class boundaries do not.


| Concept | Key point |
|---|---|
| Metric learning | Learn a distance function from data rather than assuming Euclidean distance |
| Mahalanobis distance | $\sqrt{(x-y)^\top M (x-y)}$; a true metric for positive definite $M$ (a pseudometric if $M$ is only positive semi-definite) |
| LMNN | Linear metric learning; pull target neighbors close, push impostors away by a margin |
| Siamese network | Two weight-shared copies; output is an embedding; distance determines similarity |
| Contrastive loss | Minimize distance for same-class pairs; push different-class pairs apart up to a margin |
| Triplet loss | Force anchor closer to positive than negative by margin $\alpha$; hard negative mining matters |
| Few-shot learning | Prototypical Networks classify by nearest prototype in learned metric space |
| Generalization | Metric learning transfers to unseen classes; classifiers do not |
