# coding: utf-8
# Copyright (c) 2019-2021. The MITRE Corporation.
"""
Core Neural Net architectures for topic modeling.
"""
import math
import os
import numpy as np
import logging
from tmnt.distribution import LogisticGaussianDistribution
from tmnt.distribution import BaseDistribution
from torch import nn
from torch.nn.modules.loss import _Loss
import torch
from torch.distributions.categorical import Categorical
from typing import List, Tuple, Dict, Optional, Union, NoReturn
class BaseVAE(nn.Module):
    """Common base for the variational topic models defined in this module.

    Owns the latent distribution and a linear decoder from latent (topic)
    space to the vocabulary, and provides the reconstruction-loss helpers
    used by the concrete models.

    Parameters:
        vocab_size (int): Vocabulary size; width of the decoder output.
        latent_distribution: Object exposing ``n_latent`` and ``enc_size``.
            When ``None`` a fresh ``LogisticGaussianDistribution(100, 20)``
            is created for this instance (a module instance used as a
            default argument would be shared by every model).
        coherence_reg_penalty (float): Coherence penalty weight; values > 0
            enable :meth:`add_coherence_reg_penalty` when an embedding is set.
        redundancy_reg_penalty (float): Redundancy penalty weight (stored).
        n_covars (int): Number of co-variate values (stored for subclasses).
        device (str): Device the decoder is placed on.
    """

    def __init__(self, vocab_size=2000, latent_distribution=None,
                 coherence_reg_penalty=0.0, redundancy_reg_penalty=0.0,
                 n_covars=0, device='cpu', **kwargs):
        super(BaseVAE, self).__init__(**kwargs)
        if latent_distribution is None:
            # Built lazily so that each model gets its own parameters.
            latent_distribution = LogisticGaussianDistribution(100, 20)
        self.vocab_size = vocab_size
        self.n_latent = latent_distribution.n_latent
        self.enc_size = latent_distribution.enc_size
        self.coherence_reg_penalty = coherence_reg_penalty
        self.redundancy_reg_penalty = redundancy_reg_penalty
        self.n_covars = n_covars
        self.device = device
        self.embedding = None
        self.latent_distribution = latent_distribution
        self.decoder = nn.Linear(self.n_latent, self.vocab_size).to(device)
        # NOTE(review): ``self.coherence_regularization`` is not created here
        # (it was commented out in the original); subclasses that enable the
        # coherence penalty must define it themselves.

    def initialize_bias_terms(self, wd_freqs: Optional[np.ndarray]):
        """Fix the decoder bias to smoothed log relative word frequencies.

        Parameters:
            wd_freqs: Per-term frequency counts, or ``None`` to leave the
                bias untouched.
        """
        if wd_freqs is None:
            return
        freq_nd = wd_freqs + 1  # add-one smoothing avoids log(0)
        log_freq = np.log(freq_nd) - np.log(freq_nd.sum())
        with torch.no_grad():
            self.decoder.bias = nn.Parameter(
                torch.tensor(log_freq, dtype=torch.float32, device=self.device))
            # The bias is a fixed background term, not a trained parameter.
            self.decoder.bias.requires_grad_(False)

    def get_ordered_terms(self):
        """
        Returns term indices ordered per topic based on sensitivity analysis. Terms whose
        score increases the most for a unit increase in a given topic score/probability
        are those most associated with the topic.

        Returns:
            numpy array of shape (vocab_size, n_latent); column ``k`` lists
            term indices for topic ``k`` in descending order of sensitivity.
        """
        z = torch.ones((self.n_latent,), device=self.device)
        jacobian = torch.autograd.functional.jacobian(self.decoder, z)
        sorted_j = jacobian.argsort(dim=0, descending=True)
        return sorted_j.cpu().numpy()

    def get_topic_vectors(self):
        """
        Returns unnormalized topic vectors

        Returns:
            numpy array of shape (vocab_size, n_latent): the Jacobian of the
            decoder output with respect to the latent vector.
        """
        z = torch.ones((self.n_latent,), device=self.device)
        jacobian = torch.autograd.functional.jacobian(self.decoder, z)
        # torch tensors have no ``asnumpy`` (MXNet API); use ``numpy``.
        return jacobian.cpu().numpy()

    def add_coherence_reg_penalty(self, cur_loss):
        """Add coherence/redundancy penalties to ``cur_loss`` when enabled.

        Returns:
            Tuple ``(loss, coherence_term, redundancy_term)``; both terms are
            zero tensors shaped like ``cur_loss`` when the penalty is off or
            no embedding is set.
        """
        if self.coherence_reg_penalty > 0.0 and self.embedding is not None:
            w = self.decoder.weight.data
            emb = self.embedding.weight.data
            c, d = self.coherence_regularization(w, emb)
            return (cur_loss + c + d), c, d
        zeros = torch.zeros_like(cur_loss, device=self.device)
        return (cur_loss, zeros, torch.zeros_like(cur_loss, device=self.device))

    def get_loss_terms(self, data, y, KL, batch_size):
        """Compute the per-item loss terms.

        Parameters:
            data: Bag-of-words input; multiplied element-wise with ``log(y)``.
            y: Reconstructed term probabilities, same shape as ``data``.
            KL: KL divergence term added to the reconstruction loss.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(total_loss, recon_loss, coherence_loss, redundancy_loss)``.
        """
        # The small epsilon keeps log() finite for zero probabilities.
        rr = data * torch.log(y + 1e-12)
        recon_loss = -(rr.sum(dim=1))
        i_loss = KL + recon_loss
        ii_loss, coherence_loss, redundancy_loss = self.add_coherence_reg_penalty(i_loss)
        return ii_loss, recon_loss, coherence_loss, redundancy_loss
class BowVAEModel(BaseVAE):
    """
    Defines the neural architecture for a bag-of-words topic model.

    Parameters:
        enc_dim (int): Number of dimension of input encoder (first FC layer)
        embedding_size (int): Number of dimensions for embedding layer
        n_encoding_layers (int): Number of layers used for the encoder. (default = 1)
        enc_dr (float): Dropout after each encoder layer. (default = 0.1)
        n_labels (int): Number of labels; a classifier head is added when > 1
        gamma (float): Stored on the model; not used inside this class
        multilabel (bool): Stored on the model; not used inside this class
        classifier_dropout (float): Dropout applied before the classifier head
        n_covars (int): Number of values for categorical co-variate (0 for non-CovariateData BOW model)
        device (str): context device
    """

    def __init__(self,
                 enc_dim, embedding_size, n_encoding_layers, enc_dr,
                 n_labels=0,
                 gamma=1.0,
                 multilabel=False,
                 classifier_dropout=0.1,
                 *args, **kwargs):
        super(BowVAEModel, self).__init__(*args, **kwargs)
        self.embedding_size = embedding_size
        self.num_enc_layers = n_encoding_layers
        self.enc_dr = enc_dr
        self.enc_dim = enc_dim
        self.multilabel = multilabel
        self.n_labels = n_labels
        self.gamma = gamma
        self.classifier_dropout = classifier_dropout
        self.has_classifier = self.n_labels > 1
        # First encoder layer consumes the embedding plus any co-variate inputs.
        self.encoding_dims = [self.embedding_size + self.n_covars] + [enc_dim for _ in range(n_encoding_layers)]
        self.embedding = torch.nn.Sequential()
        self.embedding.add_module("linear", torch.nn.Linear(self.vocab_size, self.embedding_size))
        self.embedding.add_module("tanh", torch.nn.Tanh())
        self.encoder = self._get_encoder(self.encoding_dims, dr=enc_dr)
        if self.has_classifier:
            self.lab_dr = torch.nn.Dropout(self.classifier_dropout)
            self.classifier = torch.nn.Linear(self.n_latent, self.n_labels, bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialize the weights of every linear layer."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight.data)

    def _get_encoder(self, dims, dr=0.1):
        """Build the encoder: Linear + Softplus (+ Dropout) per consecutive pair in ``dims``."""
        encoder = torch.nn.Sequential()
        for i in range(len(dims) - 1):
            encoder.add_module("linear_" + str(i), torch.nn.Linear(dims[i], dims[i + 1]))
            encoder.add_module("soft_" + str(i), torch.nn.Softplus())
            if dr > 0.0:
                encoder.add_module("drop_" + str(i), torch.nn.Dropout(dr))
        return encoder

    def _topic_input_jacobian(self, data):
        """Jacobian of the batch-summed latent mean w.r.t. a binarized batch.

        Summing the latent mean over the batch before differentiating keeps
        the result at shape (n_latent, batch_size, vocab_size) instead of the
        full (batch, n_latent, batch, vocab) Jacobian.
        """
        x_data = data.to(device=self.device).to_dense()
        # Clip counts to 1.0 so the analysis is done around term presence.
        x_data = torch.minimum(x_data, torch.tensor([1.0], device=self.device))

        def summed_mu(x):
            enc_out = self.latent_distribution.get_mu_encoding(
                self.encoder(self.embedding(x)), include_bn=True)
            return enc_out.sum(dim=0)

        return torch.autograd.functional.jacobian(summed_mu, x_data)

    def get_ordered_terms_encoder(self, dataloader, sample_size=-1):
        """Order terms per topic by encoder sensitivity accumulated over data.

        Parameters:
            dataloader: Iterable of ``(data, label)`` batches.
            sample_size (int): Stop once at least this many items have been
                processed; non-positive means use all batches.

        Returns:
            Integer numpy array of shape (vocab_size, n_latent); column ``k``
            lists term indices by decreasing accumulated gradient for topic ``k``.
        """
        jacobians = np.zeros(shape=(self.n_latent, self.vocab_size))
        samples = 0
        for data, _ in dataloader:
            if sample_size > 0 and samples >= sample_size:
                print("Sample processed, exiting..")
                break
            samples += data.shape[0]
            jacobian = self._topic_input_jacobian(data)
            # Collapse the batch axis: (n_latent, batch, vocab) -> (n_latent, vocab).
            jacobians += jacobian.sum(dim=1).cpu().numpy()
        sorted_j = (-jacobians).argsort(axis=1).transpose()
        return sorted_j

    def get_ordered_terms_per_item(self, dataloader, sample_size=-1):
        """Collect per-item input gradients for every topic.

        Parameters:
            dataloader: Iterable of ``(data, label)`` batches.
            sample_size (int): Stop once at least this many items have been
                processed; non-positive means use all batches.

        Returns:
            List of ``n_latent`` lists; entry ``k`` holds one numpy vector of
            length ``vocab_size`` per processed item.
        """
        jacobian_list = [[] for _ in range(self.n_latent)]
        samples = 0
        for data, _ in dataloader:
            if sample_size > 0 and samples >= sample_size:
                print("Sample processed, exiting..")
                break
            samples += data.shape[0]
            jacobian = self._topic_input_jacobian(data).cpu().numpy()
            for topic, per_item in enumerate(jacobian):
                jacobian_list[topic].extend(per_item)
        return jacobian_list

    def encode_data(self, data, include_bn=True):
        """
        Encode data to the mean of the latent distribution defined by the input `data`.

        Parameters
        ----------
        data: `torch.Tensor`
            input data of shape (batch_size, vocab_size)
        include_bn: bool
            forwarded to the latent distribution's `get_mu_encoding`

        Returns
        -------
        `torch.Tensor`
            Result of encoding with shape (batch_size, n_latent)
        """
        return self.latent_distribution.get_mu_encoding(self.encoder(self.embedding(data)), include_bn=include_bn)

    def run_encode(self, in_data, batch_size):
        """Run the encoder on ``in_data`` and apply the latent distribution to the result."""
        enc_out = self.encoder(in_data)
        return self.latent_distribution(enc_out, batch_size)

    def predict(self, data):
        """Predict the label given the input data (ignoring VAE reconstruction)

        Parameters:
            data (tensor): input data tensor
        Returns:
            output vector (tensor): unnormalized outputs over label values
        """
        return self.classifier(self.lab_dr(self.encode_data(data)))

    def forward(self, data):
        """Compute the VAE loss terms (and classifier outputs) for a batch.

        Parameters:
            data (tensor): bag-of-words batch; converted to dense first.

        Returns:
            Tuple ``(loss, KL, recon_loss, coherence_loss, redundancy_loss,
            classifier_outputs)``; ``classifier_outputs`` is ``None`` when the
            model has no classifier head.
        """
        data = data.to_dense()
        batch_size = data.shape[0]
        emb_out = self.embedding(data)
        enc_out = self.encoder(emb_out)
        z, KL = self.latent_distribution(enc_out, batch_size)
        xhat = self.decoder(z)
        y = torch.nn.functional.softmax(xhat, dim=1)
        ii_loss, recon_loss, coherence_loss, redundancy_loss = \
            self.get_loss_terms(data, y, KL, batch_size)
        if self.has_classifier:
            # The classifier reads the latent mean rather than the sample.
            mu_out = self.latent_distribution.get_mu_encoding(enc_out)
            classifier_outputs = self.classifier(self.lab_dr(mu_out))
        else:
            classifier_outputs = None
        return ii_loss, KL, recon_loss, coherence_loss, redundancy_loss, classifier_outputs
class MetricBowVAEModel(BowVAEModel):
    """Bag-of-words VAE used for paired (metric-learning) inputs.

    ``forward`` encodes two batches and returns their summed ELBO terms along
    with the latent means of each batch.
    """

    def __init__(self, *args, **kwargs):
        super(MetricBowVAEModel, self).__init__(*args, **kwargs)
        self.kld_wt = 1.0
        # Needed by get_redundancy_penalty; only the redundancy term is used,
        # mirroring the setup in BaseSeqBowVED.
        self.coherence_regularization = CoherenceRegularizer(0.0, self.redundancy_reg_penalty)

    def get_redundancy_penalty(self):
        """Return the redundancy (diversity) penalty for the decoder weights."""
        w = self.decoder.weight.data
        if self.embedding is None:
            emb = w.transpose(0, 1)
        else:
            # BowVAEModel wraps its embedding Linear in a Sequential under "linear".
            emb_layer = self.embedding.linear if isinstance(self.embedding, nn.Sequential) else self.embedding
            emb = emb_layer.weight.data
        _, redundancy_loss = self.coherence_regularization(w, emb)
        return redundancy_loss

    def _get_elbo(self, bow, enc):
        """Compute ``(elbo, reconstruction_loss, weighted_KL)`` per item."""
        batch_size = bow.shape[0]
        z, KL = self.latent_distribution(enc, batch_size)
        KL_loss = (KL * self.kld_wt)
        y = torch.nn.functional.softmax(self.decoder(z), dim=1)
        # Epsilon keeps log() finite for zero probabilities.
        rec_loss = -torch.sum(bow * torch.log(y + 1e-12), dim=1)
        elbo = rec_loss + KL_loss
        return elbo, rec_loss, KL_loss

    def _get_encoding(self, data):
        """Run embedding and encoder on ``data``."""
        return self.encoder(self.embedding(data))

    def unpaired_input_forward(self, data):
        """Loss terms for a single (unpaired) batch.

        Returns:
            Tuple ``(elbo, rec_loss, kl_loss, redundancy_loss)``.
        """
        enc = self._get_encoding(data)
        elbo, rec_loss, kl_loss = self._get_elbo(data, enc)
        redundancy_loss = self.get_redundancy_penalty()
        return elbo, rec_loss, kl_loss, redundancy_loss

    def forward(self, F, data1, data2):
        """Loss terms and latent means for a pair of batches.

        Parameters:
            F: Unused; kept so the signature stays unchanged.
                NOTE(review): looks like a leftover from the MXNet API - confirm
                how callers invoke this before removing it.
            data1 (tensor): first bag-of-words batch
            data2 (tensor): second bag-of-words batch

        Returns:
            Tuple ``(elbo, rec_loss, KL_loss, redundancy_loss, mu1, mu2)`` where
            the first three are sums over the two batches.
        """
        enc1 = self._get_encoding(data1)
        enc2 = self._get_encoding(data2)
        mu1 = self.latent_distribution.get_mu_encoding(enc1)
        mu2 = self.latent_distribution.get_mu_encoding(enc2)
        elbo1, rec_loss1, KL_loss1 = self._get_elbo(data1, enc1)
        elbo2, rec_loss2, KL_loss2 = self._get_elbo(data2, enc2)
        redundancy_loss = self.get_redundancy_penalty()
        return (elbo1 + elbo2), (rec_loss1 + rec_loss2), (KL_loss1 + KL_loss2), redundancy_loss, mu1, mu2
class CovariateBowVAEModel(BowVAEModel):
    """Bag-of-words topic model with labels used as co-variates

    Parameters:
        covar_net_layers (int): Number of layers for the continuous
            co-variate decoder (used when ``n_covars < 1``).
    """

    def __init__(self, covar_net_layers=1, *args, **kwargs):
        super(CovariateBowVAEModel, self).__init__(*args, **kwargs)
        self.covar_net_layers = covar_net_layers
        # nn.Module has no ``name_scope`` (MXNet API); submodules are simply assigned.
        if self.n_covars < 1:
            self.cov_decoder = ContinuousCovariateModel(self.n_latent, self.vocab_size,
                                                        total_layers=self.covar_net_layers, device=self.device)
        else:
            self.cov_decoder = CovariateModel(self.n_latent, self.n_covars, self.vocab_size,
                                              interactions=True, device=self.device)

    def encode_data_with_covariates(self, data, covars, include_bn=False):
        """
        Encode data to the mean of the latent distribution defined by the input `data`

        ``covars`` is concatenated to the embedding along the feature axis, so
        it must be 2-D with one row per item.
        """
        emb_out = self.embedding(data)
        enc_out = self.encoder(torch.cat((emb_out, covars), dim=1))
        return self.latent_distribution.get_mu_encoding(enc_out, include_bn=include_bn)

    def _term_topic_jacobian(self, data, covar):
        """Jacobian of batch-summed term probabilities w.r.t. the latent mean.

        Returns a tensor of shape (vocab_size, n_latent), summed over the
        batch on both the output and the input side.
        """
        data = data.to(self.device)
        covar = covar.to(self.device)
        co_emb = torch.cat((self.embedding(data), covar), dim=1)
        z = self.latent_distribution.get_mu_encoding(self.encoder(co_emb)).detach()

        def summed_term_probs(latent):
            logits = self.decoder(latent) + self.cov_decoder(latent, covar)
            # Summing over the batch first keeps the Jacobian at (V, B, K).
            return torch.softmax(logits, dim=1).sum(dim=0)

        jacobian = torch.autograd.functional.jacobian(summed_term_probs, z)
        return jacobian.sum(dim=1)

    def get_ordered_terms_with_covar_at_data(self, data, k, covar):
        """
        Uses test/training data-point as the input points around which term sensitivity is computed

        Parameters:
            data (tensor): bag-of-words batch
            k: Unused; kept for interface compatibility.
            covar (tensor): co-variate batch, 2-D with one row per item

        Returns:
            Tensor of shape (vocab_size, n_latent) with term indices sorted by
            decreasing sensitivity for each topic.
        """
        jacobian = self._term_topic_jacobian(data, covar)
        sorted_j = jacobian.argsort(dim=0, descending=True)
        return sorted_j

    def get_topic_vectors(self, data, covar):
        """
        Returns unnormalized topic vectors based on the input data

        Returns:
            Tensor of shape (vocab_size, n_latent).
        """
        return self._term_topic_jacobian(data, covar)

    def forward(self, F, data, covars):
        """Compute the loss terms for a batch with co-variates.

        Parameters:
            F: Unused; kept so the signature stays unchanged.
                NOTE(review): looks like a leftover from the MXNet API - confirm
                how callers invoke this before removing it.
            data (tensor): bag-of-words batch
            covars (tensor): co-variate indices (one-hot encoded here when
                ``n_covars > 0``) or values already shaped (batch_size, d).

        Returns:
            Tuple ``(loss, KL, recon_loss, coherence_loss, redundancy_loss, None)``.
        """
        batch_size = data.shape[0]
        emb_out = self.embedding(data)
        if self.n_covars > 0:
            covars = torch.nn.functional.one_hot(covars.long(), self.n_covars).to(emb_out.dtype)
        co_emb = torch.cat((emb_out, covars), dim=1)
        z, KL = self.run_encode(co_emb, batch_size)
        dec_out = self.decoder(z)
        cov_dec_out = self.cov_decoder(z, covars)
        y = torch.softmax(dec_out + cov_dec_out, dim=1)
        ii_loss, recon_loss, coherence_loss, redundancy_loss = \
            self.get_loss_terms(data, y, KL, batch_size)
        return ii_loss, KL, recon_loss, coherence_loss, redundancy_loss, None
class CovariateModel(nn.Module):
    """Linear decoder contribution from categorical co-variates.

    Parameters:
        n_topics (int): Number of topics (width of ``topic_distrib``).
        n_covars (int): Number of co-variate values (width of ``covars``).
        vocab_size (int): Output width.
        interactions (bool): Also model topic x co-variate interactions.
        device (str): Stored on the module; layers are not moved here.
    """

    def __init__(self, n_topics, n_covars, vocab_size, interactions=False, device='cpu'):
        super(CovariateModel, self).__init__()
        self.n_topics = n_topics
        self.n_covars = n_covars
        self.vocab_size = vocab_size
        self.interactions = interactions
        self.device = device
        # nn.Module has no ``name_scope`` (MXNet API); layers are assigned directly.
        self.cov_decoder = torch.nn.Linear(n_covars, self.vocab_size, bias=False)
        if self.interactions:
            self.cov_inter_decoder = torch.nn.Linear(self.n_covars * self.n_topics, self.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialize the weights of every linear layer."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight.data)

    def forward(self, topic_distrib, covars):
        """Score the vocabulary from co-variates (and their topic interactions).

        Parameters:
            topic_distrib (tensor): shape (N, n_topics)
            covars (tensor): shape (N, n_covars)

        Returns:
            Tensor of shape (N, vocab_size).
        """
        score_C = self.cov_decoder(covars)
        if not self.interactions:
            return score_C
        td_rsh = topic_distrib.unsqueeze(1)   # (N, 1, Topics)
        cov_rsh = covars.unsqueeze(2)         # (N, Covariates, 1)
        cov_interactions = cov_rsh * td_rsh   # (N, Covariates, Topics) -- outer product
        batch_size = cov_interactions.shape[0]
        cov_interactions_rsh = torch.reshape(cov_interactions, (batch_size, self.n_topics * self.n_covars))
        score_CI = self.cov_inter_decoder(cov_interactions_rsh)
        return score_CI + score_C
class ContinuousCovariateModel(nn.Module):
    """Feed-forward decoder contribution from a continuous co-variate.

    Parameters:
        n_topics (int): Number of topics (width of ``topic_distrib``).
        vocab_size (int): Output width.
        total_layers (int): Number of hidden Linear+ReLU layers.
        device (str): Stored on the module; layers are not moved here.
    """

    def __init__(self, n_topics, vocab_size, total_layers=1, device='cpu'):
        super(ContinuousCovariateModel, self).__init__()
        self.n_topics = n_topics
        self.n_scalars = 1  # number of continuous variables
        self.device = device
        self.time_topic_dim = 300
        self.cov_decoder = nn.Sequential()
        for i in range(total_layers):
            # Only the first layer sees the raw (topics + scalar) input and has a bias.
            in_units = self.n_scalars + self.n_topics if i < 1 else self.time_topic_dim
            self.cov_decoder.add_module("linear_" + str(i), nn.Linear(in_units, self.time_topic_dim,
                                                                      bias=(i < 1)))
            self.cov_decoder.add_module("relu_" + str(i), nn.ReLU())
        self.cov_decoder.add_module("linear_out_", nn.Linear(self.time_topic_dim, vocab_size, bias=False))

    def forward(self, topic_distrib, scalars):
        """Score the vocabulary from topics and the continuous co-variate.

        Parameters:
            topic_distrib (tensor): shape (N, n_topics)
            scalars (tensor): shape (N, n_scalars); a 1-D tensor of length N
                is treated as a single column.

        Returns:
            Tensor of shape (N, vocab_size).
        """
        if scalars.dim() == 1:
            scalars = scalars.unsqueeze(1)
        # Concatenate along the feature axis so the width matches the first layer.
        inputs = torch.cat((topic_distrib, scalars), 1)
        sc_transform = self.cov_decoder(inputs)
        return sc_transform
class CoherenceRegularizer(nn.Module):
    """Coherence and redundancy penalties for topic-term weights.

    Follows the coherence loss of http://aclweb.org/anthology/D18-1096

    Parameters:
        coherence_pen (float): Multiplier for the coherence term.
        redundancy_pen (float): Multiplier for the redundancy term.
    """

    def __init__(self, coherence_pen=1.0, redundancy_pen=1.0):
        super(CoherenceRegularizer, self).__init__()
        self.coherence_pen = coherence_pen
        self.redundancy_pen = redundancy_pen

    def forward(self, w, emb):
        """Return ``(coherence_penalty, redundancy_penalty)``.

        Parameters:
            w (tensor): topic-term weights, shape (V x K)
            emb (tensor): term embeddings, shape (D x V)
        """
        # Shift every column so its smallest weight is zero (non-negative weights).
        column_floor = w.min(dim=0, keepdim=True).values
        shifted = w - column_floor
        # Unit-normalize weights per column and embeddings per row.
        unit_w = shifted / torch.norm(shifted, dim=0, keepdim=True)
        unit_emb = emb / torch.norm(emb, dim=1, keepdim=True)
        # Topic vectors in embedding space, column-normalized: (D x K).
        topics = torch.matmul(unit_emb, unit_w)
        unit_topics = topics / torch.norm(topics, dim=0, keepdim=True)
        # Term-to-topic similarity: (V x K).
        similarity = torch.matmul(unit_emb.t(), unit_topics)
        coherence = -(similarity * unit_w).sum()
        # Diversity component: total pairwise similarity between topics.
        redundancy = torch.matmul(unit_topics.t(), unit_topics).sum()
        return coherence * self.coherence_pen, redundancy * self.redundancy_pen
class BaseSeqBowVED(BaseVAE):
    """Base for sequence-encoder / bag-of-words-decoder topic models.

    Parameters:
        llm: Sequence encoder; called as ``llm(input_ids, attention_mask)``
            and expected to return an object with ``last_hidden_state``.
        latent_dist: Latent distribution applied to the sequence encoding.
        num_classes (int): A classifier is enabled when this is >= 2.
        dropout (float): Dropout rate (stored for subclasses).
        vocab_size (int): Vocabulary size of the bag-of-words decoder.
        kld (float): Weight applied to the KL term.
        device (str): Device for the decoder.
        use_pooling (bool): Mean-pool token embeddings instead of using the
            first token.
        entropy_loss_coef (float): Weight of the latent entropy term in the ELBO.
        redundancy_reg_penalty (float): Redundancy penalty weight.
        pre_trained_embedding: Optional object exposing ``idx_to_vec``; only
            its dimensions are used to size the embedding layer here.
    """

    def __init__(self,
                 llm,
                 latent_dist,
                 num_classes=0,
                 dropout=0.0,
                 vocab_size=2000,
                 kld=0.1,
                 device='cpu',
                 use_pooling=True,
                 entropy_loss_coef=1000.0,
                 redundancy_reg_penalty=0.0, pre_trained_embedding=None):
        # Hand the real latent distribution to the base class so that
        # n_latent, enc_size and the decoder are derived from it directly
        # rather than from a throwaway default distribution.
        super(BaseSeqBowVED, self).__init__(vocab_size=vocab_size,
                                            latent_distribution=latent_dist,
                                            redundancy_reg_penalty=redundancy_reg_penalty,
                                            device=device)
        self.llm = llm
        self.kld_wt = kld
        self.has_classifier = num_classes >= 2
        self.num_classes = num_classes
        self.dropout = dropout
        self.coherence_regularization = CoherenceRegularizer(0.0, self.redundancy_reg_penalty)
        self.use_pooling = use_pooling
        self.entropy_loss_coef = entropy_loss_coef
        if pre_trained_embedding is not None:
            # NOTE(review): only the layer size is taken from the pre-trained
            # vectors; no weights are copied here - confirm this is intended.
            self.embedding = nn.Linear(len(pre_trained_embedding.idx_to_vec),
                                       pre_trained_embedding.idx_to_vec[0].size, bias=False)
        # The llm is deliberately excluded from re-initialization.
        self.latent_distribution.apply(self._init_weights)
        self.decoder.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialize the weights of every linear layer."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight.data)

    def _get_embedding(self, model_output, attention_mask):
        """Reduce token embeddings to one vector per sequence.

        With pooling, returns the attention-mask-weighted mean over tokens;
        otherwise the hidden state of the first token.
        """
        if self.use_pooling:
            token_embeddings = model_output.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # clamp avoids division by zero for fully-masked sequences
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return model_output.last_hidden_state[:, 0, :]

    def get_ordered_terms(self):
        """
        Returns term indices ordered per topic based on sensitivity analysis. Terms whose
        score increases the most for a unit increase in a given topic score/probability
        are those most associated with the topic.

        Returns:
            numpy array of shape (vocab_size, n_latent).
        """
        z = torch.ones((self.n_latent,), device=self.device)
        jacobian = torch.autograd.functional.jacobian(self.decoder, z)
        sorted_j = jacobian.argsort(dim=0, descending=True)
        return sorted_j.cpu().numpy()

    def get_redundancy_penalty(self):
        """Return the redundancy (diversity) penalty for the decoder weights."""
        w = self.decoder.weight.data
        emb = self.embedding.weight.data if self.embedding is not None else w.transpose(0, 1)
        _, redundancy_loss = self.coherence_regularization(w, emb)
        return redundancy_loss

    def _get_latent_sparsity_term(self, encoding):
        """Entropy of ``encoding`` read as a categorical distribution.

        ``encoding`` is treated as probabilities when the latent distribution
        is on the simplex and as logits otherwise.
        """
        if self.latent_distribution.on_simplex:
            as_distribution = Categorical(probs=encoding)
        else:
            as_distribution = Categorical(logits=encoding)
        return as_distribution.entropy()

    def _get_elbo(self, bow, enc):
        """Compute ``(elbo, rec_loss, weighted_KL, entropy_loss)`` per item.

        The returned ELBO already includes ``entropy_loss * entropy_loss_coef``.
        """
        z, KL = self.latent_distribution(enc, bow.size()[0])
        KL_loss = (KL * self.kld_wt)
        dec = self.decoder(z)
        y = torch.nn.functional.softmax(dec, dim=1)
        # Epsilon keeps log() finite for zero probabilities.
        rec_loss = -torch.sum(bow.to_dense() * torch.log(y + 1e-12), dim=1)
        entropy_loss = self._get_latent_sparsity_term(z)
        elbo = rec_loss + KL_loss + (entropy_loss * self.entropy_loss_coef)
        return elbo, rec_loss, KL_loss, entropy_loss

    def forward_encode(self, input_ids, attention_mask):
        """Encode token ids to the mean of the latent distribution."""
        llm_output = self.llm(input_ids, attention_mask)
        cls_vec = self._get_embedding(llm_output, attention_mask)
        return self.latent_distribution.get_mu_encoding(cls_vec)
class SeqBowVED(BaseSeqBowVED):
    """Sequence-encoder / bag-of-words-decoder model with an optional classifier head."""

    def __init__(self, *args, **kwargs):
        super(SeqBowVED, self).__init__(*args, **kwargs)
        if self.has_classifier:
            self.classifier = torch.nn.Sequential()
            self.classifier.add_module("dr", nn.Dropout(self.dropout).to(self.device))
            self.classifier.add_module("l_out", nn.Linear(self.n_latent, self.num_classes).to(self.device))

    def forward(self, input_ids, attention_mask, bow=None):  # pylint: disable=arguments-differ
        """Encode the token sequence and compute loss terms / classifier outputs.

        Returns the tuple produced by :meth:`forward_with_cached_encoding`.
        """
        llm_output = self.llm(input_ids, attention_mask)
        cls_vec = self._get_embedding(llm_output, attention_mask)
        return self.forward_with_cached_encoding(cls_vec, bow)

    def forward_with_cached_encoding(self, enc, bow):
        """Compute loss terms from an already-computed sequence encoding.

        Parameters:
            enc (tensor): sequence encoding, one row per item
            bow: bag-of-words targets, or ``None`` to skip the ELBO (all loss
                terms are then ``0.0``).

        Returns:
            Tuple ``(elbo, rec_loss, KL_loss, redundancy_loss,
            classifier_outputs)``; ``classifier_outputs`` is ``None`` when the
            model has no classifier.
        """
        # entropy_loss must have a value even when no bag-of-words is given.
        elbo, rec_loss, KL_loss, entropy_loss = 0.0, 0.0, 0.0, 0.0
        if bow is not None:
            elbo, rec_loss, KL_loss, entropy_loss = self._get_elbo(bow, enc)
        if self.has_classifier:
            z_mu = self.latent_distribution.get_mu_encoding(enc)
            classifier_outputs = self.classifier(z_mu)
        else:
            classifier_outputs = None
        # The entropy term is reported in the redundancy slot and added to the ELBO.
        redundancy_loss = entropy_loss
        elbo = elbo + redundancy_loss
        return elbo, rec_loss, KL_loss, redundancy_loss, classifier_outputs
class MetricSeqBowVED(BaseSeqBowVED):
    """Sequence-encoder / bag-of-words-decoder model for paired inputs."""

    def __init__(self, *args, **kwargs):
        super(MetricSeqBowVED, self).__init__(*args, **kwargs)

    def _encode(self, token_ids, mask):
        """Run the llm and reduce its output to one vector per sequence."""
        return self._get_embedding(self.llm(token_ids, mask), mask)

    def unpaired_input_forward(self, in1, mask1, bow1):
        """Loss terms for a single (unpaired) batch.

        Returns:
            Tuple ``(elbo, rec_loss, KL_loss, redundancy_loss)`` where the
            redundancy slot carries the latent entropy term.
        """
        encoding = self._encode(in1, mask1)
        elbo, rec_loss, kl_loss, entropy = self._get_elbo(bow1, encoding)
        return elbo, rec_loss, kl_loss, entropy

    def forward(self, in1, mask1, bow1, in2, mask2, bow2):
        """Loss terms and encodings for a pair of batches.

        Returns:
            Tuple ``(elbo, rec_loss, KL_loss, redundancy_loss, enc1, enc2)``;
            the first four are sums over both batches and the redundancy slot
            carries the summed latent entropy terms.
        """
        # Both llm passes run before either pooling step, as in the original ordering.
        first_out = self.llm(in1, mask1)
        second_out = self.llm(in2, mask2)
        enc1 = self._get_embedding(first_out, mask1)
        enc2 = self._get_embedding(second_out, mask2)
        first_terms = self._get_elbo(bow1, enc1)
        second_terms = self._get_elbo(bow2, enc2)
        elbo, rec_loss, kl_loss, entropy = (a + b for a, b in zip(first_terms, second_terms))
        return elbo, rec_loss, kl_loss, entropy, enc1, enc2
class GeneralizedSDMLLoss(_Loss):
    r"""Calculates Batchwise Smoothed Deep Metric Learning (SDML) Loss given two input tensors and a smoothing weight
    SDM Loss learns similarity between paired samples by using unpaired samples in the minibatch
    as potential negative examples.

    The loss is described in greater detail in
    "Large Scale Question Paraphrase Retrieval with Smoothed Deep Metric Learning."
    - by Bonadiman, Daniele, Anjishnu Kumar, and Arpit Mittal.  arXiv preprint arXiv:1905.12786 (2019).
    URL: https://arxiv.org/pdf/1905.12786.pdf

    Parameters
    ----------
    smoothing_parameter : float
        Probability mass to be distributed over the minibatch. Must be < 1.0.
    weight : float or None
        Global scalar weight for loss. Stored on the instance; not applied here.
    batch_axis : int, default 0
        The axis that represents mini-batch. Stored on the instance.
    x2_downweight_idx : int, default -1
        When >= 0, the loss is scaled by the fraction of ``l2`` labels that
        differ from this value.

    Inputs:
        - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
        - **l1**: Labels for x1 with shape (batch_size,)
        - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
        - **l2**: Labels for x2 with shape (batch_size,)
          Each item in x1 is a positive sample for the items with the same label in x2
          That is, x1[0] and x2[0] form a positive pair iff label(x1[0]) = label(x2[0])
          All data points in different rows should be decorrelated

    Outputs:
        - **loss**: scalar tensor, the KL divergence summed over the batch.
    """

    def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, x2_downweight_idx=-1, **kwargs):
        # ``weight``/``batch_axis`` are not arguments of torch's ``_Loss``;
        # passing them positionally would be read as size_average/reduce.
        super(GeneralizedSDMLLoss, self).__init__(**kwargs)
        self.weight = weight
        self.batch_axis = batch_axis
        # 'sum' is what the former ``size_average=False`` resolved to.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.smoothing_parameter = smoothing_parameter  # Smoothing probability mass
        self.x2_downweight_idx = x2_downweight_idx

    def _compute_distances(self, x1, x2):
        """
        This function computes the squared euclidean distance between every
        vector in the two batches in input.

        Returns a (batch_size, batch_size) tensor whose entry [i, j] is the
        squared distance between x1[i] and x2[j].
        """
        # extracting sizes expecting [batch_size, dim]
        assert x1.size() == x2.size()
        batch_size, dim = x1.size()
        # expanding both tensor form [batch_size, dim] to [batch_size, batch_size, dim]
        x1_ = x1.unsqueeze(1).broadcast_to([batch_size, batch_size, dim])
        x2_ = x2.unsqueeze(0).broadcast_to([batch_size, batch_size, dim])
        # pointwise squared differences
        squared_diffs = (x1_ - x2_) ** 2
        # sum of squared differences distance
        return squared_diffs.sum(dim=2)

    def _compute_labels(self, l1: torch.Tensor, l2: torch.Tensor):
        """
        Build the smoothed, row-normalized target matrix and the down-weight.

        Example (before smoothing/normalization):
        l1 = [1,2,2]
        l2 = [1,2,1]
        ===>
        [ [ 1, 0, 1],
          [ 0, 1, 0],
          [ 0, 1, 0] ]
        This is an outer product with the equality predicate.

        Returns:
            Tuple ``(labels, down_wt)``: labels has shape
            (batch_size, batch_size) with rows summing to 1; down_wt is a float.
        """
        batch_size = l1.size()[0]
        l1_x = l1.unsqueeze(1).expand(batch_size, batch_size)
        l2_x = l2.unsqueeze(0).expand(batch_size, batch_size)
        ll = torch.eq(l1_x, l2_x)
        # With a single item there are no negatives to spread mass over;
        # guard the denominator so the result is not NaN.
        n_negatives = max(batch_size - 1, 1)
        labels = ll * (1 - self.smoothing_parameter) + (~ll) * self.smoothing_parameter / n_negatives
        ## now normalize rows to sum to 1.0
        labels = labels / labels.sum(dim=1, keepdim=True).expand(batch_size, batch_size)
        if self.x2_downweight_idx >= 0:
            # Counted with torch so this also works for tensors on an accelerator.
            down_wt = (l2 != self.x2_downweight_idx).sum().item() / batch_size
        else:
            down_wt = 1.0
        return labels, down_wt

    def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
        """
        the function computes the kl divergence between the negative distances
        and the smoothed label matrix.
        """
        labels, wt = self._compute_labels(l1, l2)
        distances = self._compute_distances(x1, x2)
        log_probabilities = torch.log_softmax(-distances, dim=1)
        # kl_loss sums over all entries; wt scales the result.
        kl = self.kl_loss(log_probabilities, labels.to(distances.device)) * wt
        return kl

    def forward(self, x1, l1, x2, l2):
        """Return the SDML loss for the paired batches (see class docstring)."""
        return self._loss(x1, l1, x2, l2)
class MultiNegativeCrossEntropyLoss(_Loss):
    """Cross-entropy over in-batch cosine similarities with smoothed targets.

    Parameters:
        smoothing_parameter (float): Probability mass spread over
            non-matching items in the batch.
        metric_loss_temp (float): Temperature dividing the similarities.
        batch_axis (int): Stored on the instance; not used here.

    Inputs:
        - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
        - **l1**: Labels for x1 with shape (batch_size,)
        - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
        - **l2**: Labels for x2 with shape (batch_size,)
          Each item in x1 is a positive sample for the items with the same label in x2
          That is, x1[0] and x2[0] form a positive pair iff label(x1[0]) = label(x2[0])
          All data points in different rows should be decorrelated

    Outputs:
        - **loss**: scalar tensor (``nn.CrossEntropyLoss`` default reduction).
    """

    def __init__(self, smoothing_parameter=0.1, metric_loss_temp=0.1, batch_axis=0, **kwargs):
        # ``batch_axis`` is not an argument of torch's ``_Loss``; passing it
        # positionally would be read as the deprecated ``size_average``.
        super(MultiNegativeCrossEntropyLoss, self).__init__(**kwargs)
        self.batch_axis = batch_axis
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.smoothing_parameter = smoothing_parameter  # Smoothing probability mass
        self.metric_loss_temp = metric_loss_temp

    def _compute_distances(self, x1, x2):
        """
        This function computes the cosine similarity between every vector
        in the two batches in input.

        Returns a (batch_size, batch_size) tensor whose entry [i, j] is the
        cosine similarity of x1[i] and x2[j].
        """
        # extracting sizes expecting [batch_size, dim]
        assert x1.size() == x2.size()
        # L2-normalize rows so the matrix product yields cosine similarities
        x1_norm = torch.nn.functional.normalize(x1, p=2, dim=1)
        x2_norm = torch.nn.functional.normalize(x2, p=2, dim=1)
        return torch.mm(x1_norm, x2_norm.transpose(0, 1))

    def _compute_labels(self, l1: torch.Tensor, l2: torch.Tensor):
        """
        Build the smoothed, row-normalized target matrix.

        Example (before smoothing/normalization):
        l1 = [1,2,2]
        l2 = [1,2,1]
        ===>
        [ [ 1, 0, 1],
          [ 0, 1, 0],
          [ 0, 1, 0] ]
        This is an outer product with the equality predicate.
        """
        batch_size = l1.size()[0]
        l1_x = l1.unsqueeze(1).expand(batch_size, batch_size)
        l2_x = l2.unsqueeze(0).expand(batch_size, batch_size)
        ll = torch.eq(l1_x, l2_x)
        # With a single item there are no negatives to spread mass over;
        # guard the denominator so the result is not NaN.
        n_negatives = max(batch_size - 1, 1)
        labels = ll * (1 - self.smoothing_parameter) + (~ll) * self.smoothing_parameter / n_negatives
        ## now normalize rows to sum to 1.0
        labels = labels / labels.sum(dim=1, keepdim=True).expand(batch_size, batch_size)
        return labels

    def _loss(self, x1: torch.Tensor, l1: torch.Tensor, x2: torch.Tensor, l2: torch.Tensor):
        """
        the function computes the cross entropy between the temperature-scaled
        similarities and the smoothed label matrix.
        """
        labels = self._compute_labels(l1, l2)
        # Similarities act as logits; the temperature sharpens them.
        distances = self._compute_distances(x1, x2) / self.metric_loss_temp
        return self.cross_entropy_loss(distances, labels.to(distances.device))

    def forward(self, x1, l1, x2, l2):
        """Return the loss for the paired batches (see class docstring)."""
        return self._loss(x1, l1, x2, l2)