tmnt.distribution

Variational latent distributions (e.g. Gaussian, Logistic Gaussian)

Classes

BaseDistribution(enc_size, n_latent, device)
GaussianDistribution(enc_size, n_latent[, ...]) Gaussian latent distribution with diagonal covariance.
GaussianUnitVarDistribution(n_latent[, ...]) Gaussian latent distribution with fixed unit variance.
LogisticGaussianDistribution(enc_size, n_latent) Logistic normal/Gaussian latent distribution with specified prior
Projection(enc_size, n_latent[, device])
VonMisesDistribution(enc_size, n_latent[, ...])
class BaseDistribution(enc_size, n_latent, device, on_simplex=False)

Bases: Module

class GaussianDistribution(enc_size, n_latent, device='cpu', dr=0.2)

Bases: BaseDistribution

Gaussian latent distribution with diagonal covariance.

Parameters:
  • n_latent (int) – Dimensionality of the latent distribution
  • device (device) – Torch computational context (cpu or gpu[id])
  • dr (float) – Dropout rate applied to the sample after it is drawn. optional (default = 0.2)
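The dr parameter is described only as dropout applied after sampling. A minimal, framework-free sketch of what that typically means, assuming it behaves like standard inverted dropout (the scheme torch.nn.Dropout uses): during training each latent coordinate is zeroed with probability dr and the survivors are rescaled. The function name dropout and the list-of-floats representation are hypothetical, not TMNT's code.

```python
import random


def dropout(sample, dr=0.2, training=True, rng=random):
    """Hypothetical inverted dropout over a latent sample (list of floats).

    In training mode each coordinate is zeroed with probability ``dr`` and the
    surviving coordinates are scaled by 1 / (1 - dr), which keeps the expected
    value of every coordinate unchanged. In evaluation mode the sample is
    returned unchanged.
    """
    if not 0.0 <= dr < 1.0:
        raise ValueError("dr must be in [0, 1)")
    if not training or dr == 0.0:
        return list(sample)
    scale = 1.0 / (1.0 - dr)
    return [0.0 if rng.random() < dr else x * scale for x in sample]
```

For example, dropout([0.5, -1.0, 2.0], dr=0.2, rng=random.Random(0)) returns each coordinate either as 0.0 or divided by 0.8, while dropout(z, training=False) returns z unchanged.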
forward(data, batch_size)

Generate a sample according to the Gaussian given the encoder outputs
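A framework-free sketch of drawing a sample from a diagonal Gaussian with the reparameterization trick, together with the KL term against a standard normal prior that such a model usually adds to its loss. It assumes the encoder output is split into a mean mu and a log-variance log_var; that parameterization, the N(0, I) prior, and the helper names are assumptions for illustration, not TMNT's documented internals.

```python
import math
import random


def sample_diag_gaussian(mu, log_var, rng=random):
    """Reparameterized draw z = mu + sigma * eps with eps ~ N(0, 1) per coordinate.

    sigma = exp(0.5 * log_var), so the randomness lives entirely in eps and
    gradients can flow through mu and log_var in an autodiff framework.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))
```

With mu = [1.0] and log_var = [0.0] the KL term is 0.5 * (1 + 1 - 1 - 0) = 0.5; it is zero exactly when the posterior equals the prior.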

get_mu_encoding(data, include_bn=True, normalize=False)

Provide the distribution mean as the natural result of running the full encoder

Parameters: data (torch.Tensor) – Output of pre-latent encoding layers
Returns: Encoding vector representing unnormalized topic proportions
Return type: encoding (torch.Tensor)
class GaussianUnitVarDistribution(n_latent, device='cpu', dr=0.2, var=1.0)

Bases: BaseDistribution

Gaussian latent distribution with fixed unit variance.

Parameters:
  • n_latent (int) – Dimensionality of the latent distribution
  • device (device) – Torch computational context (cpu or gpu[id])
  • dr (float) – Dropout rate applied to the sample after it is drawn. optional (default = 0.2)
forward(data, batch_size)

Generate a sample according to the unit variance Gaussian given the encoder outputs
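With a fixed variance only the mean depends on the encoder, so sampling and the KL term simplify. A minimal sketch, assuming the var argument in the signature is that fixed variance (default 1.0) and the prior is N(0, I); the helper names are hypothetical.

```python
import math
import random


def sample_fixed_var_gaussian(mu, var=1.0, rng=random):
    """Draw z = mu + sqrt(var) * eps with eps ~ N(0, 1); only mu comes from the encoder."""
    std = math.sqrt(var)
    return [m + std * rng.gauss(0.0, 1.0) for m in mu]


def kl_fixed_var_to_standard_normal(mu, var=1.0):
    """KL( N(mu, var * I) || N(0, I) ); reduces to 0.5 * sum(mu ** 2) when var == 1."""
    return 0.5 * sum(var + m * m - 1.0 - math.log(var) for m in mu)
```

For mu = [1.0, 2.0] and unit variance the KL term is 0.5 * (1 + 4) = 2.5, i.e. it acts as a plain squared-norm penalty on the mean.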

get_mu_encoding(data, include_bn=True, normalize=False)

Provide the distribution mean as the natural result of running the full encoder

Parameters: data (torch.Tensor) – Output of pre-latent encoding layers
Returns: Encoding vector representing unnormalized topic proportions
Return type: encoding (torch.Tensor)
class LogisticGaussianDistribution(enc_size, n_latent, device='cpu', dr=0.1, alpha=1.0)

Bases: BaseDistribution

Logistic normal/Gaussian latent distribution with specified prior

Parameters:
  • n_latent (int) – Dimensionality of the latent distribution
  • device (device) – Torch computational context (cpu or gpu[id])
  • dr (float) – Dropout rate applied to the sample after it is drawn. optional (default = 0.1)
  • alpha (float) – Value that determines the prior variance as 1/alpha - (2/n_latent) + 1/(n_latent^2)
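The prior-variance expression from the alpha parameter can be evaluated directly. This sketch computes exactly the documented formula (the function name is hypothetical). Two arithmetic consequences are worth noting: for the default alpha = 1.0 it equals (1 - 1/n_latent)^2, and because nothing bounds 1/alpha from below, it is not guaranteed to be positive when alpha is large relative to n_latent.

```python
def logistic_gaussian_prior_var(alpha, n_latent):
    """Evaluate the documented prior variance 1/alpha - 2/n_latent + 1/n_latent**2.

    Note: the result can be zero or negative for large alpha (e.g. alpha=10,
    n_latent=4), so callers may want to check it before using it as a variance.
    """
    if alpha <= 0 or n_latent <= 0:
        raise ValueError("alpha and n_latent must be positive")
    return 1.0 / alpha - 2.0 / n_latent + 1.0 / (n_latent ** 2)
```

For example, alpha = 2.0 with n_latent = 4 gives 0.5 - 0.5 + 0.0625 = 0.0625.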
forward(data, batch_size)

Generate a sample according to the logistic Gaussian latent distribution given the encoder outputs
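A logistic normal sample is a Gaussian sample mapped onto the probability simplex with a softmax. A minimal sketch, assuming mean/log-variance encoder outputs as in the Gaussian case; whether TMNT applies the softmax inside forward or later (the base class exposes an on_simplex flag) is not stated here, and the helper names are hypothetical.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_logistic_normal(mu, log_var, rng=random):
    """Draw z ~ N(mu, diag(exp(log_var))) via reparameterization, then map z to the simplex."""
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
    return softmax(z)
```

Every returned vector has strictly positive entries that sum to 1, so it can be read directly as topic proportions.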

get_mu_encoding(data, include_bn=True, normalize=False)

Provide the distribution mean as the natural result of running the full encoder

Parameters: data (torch.Tensor) – Output of pre-latent encoding layers
Returns: Encoding vector representing unnormalized topic proportions
Return type: encoding (torch.Tensor)
class VonMisesDistribution(enc_size, n_latent, kappa=100.0, dr=0.1, device='cpu')

Bases: BaseDistribution

forward(data, batch_size)

Generate a sample according to the von Mises latent distribution given the encoder outputs

Note

Although the forward pass is defined in this method, call the module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently ignores them.

get_mu_encoding(data, include_bn=True, normalize=False)

Provide the distribution mean as the natural result of running the full encoder

Parameters: data (torch.Tensor) – Output of pre-latent encoding layers
Returns: Encoding vector representing unnormalized topic proportions
Return type: encoding (torch.Tensor)