Use KL divergence etc., as done in e.g. https://github.com/tung-nd/E2C-pytorch/blob/master/train_e2c.py#L31
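For reference, a minimal sketch of the closed-form KL terms that usually appear in this kind of loss, assuming diagonal Gaussians parameterised by a mean and a log-variance. This is a generic formula, not a copy of the linked script, which may parameterise things differently; the function names `kl_to_standard_normal` and `kl_diag_gaussians` are hypothetical.

```python
import math
from typing import Sequence


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for a single sample."""
    # Closed form: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def kl_diag_gaussians(
    mu_q: Sequence[float],
    logvar_q: Sequence[float],
    mu_p: Sequence[float],
    logvar_p: Sequence[float],
) -> float:
    """KL( N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p))) )."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        # Per-dimension term: log(sigma_p / sigma_q) + (var_q + (mq - mp)^2) / (2 var_p) - 1/2
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total


if __name__ == "__main__":
    # Hypothetical values, for illustration only.
    print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))
    print(kl_diag_gaussians([1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]))
```

The first function would typically serve as the regulariser on the encoder output; the second covers the case where two predicted latent distributions are compared against each other. In a real training loop these would be written with tensors and averaged over the batch.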