Conversation

@ruiheng123 (Contributor) commented Jul 31, 2024

@puyuan1996 puyuan1996 added the enhancement New feature or request label Aug 5, 2024

# Initialize the total loss tensor on the correct device
self.loss_total = torch.tensor(0., device=device)
for k, v in kwargs.items():
Collaborator commented:

# Map each loss name to its static weight and the key of its learnable
# harmony parameter in harmony_s_dict
loss_weights = {
    'loss_obs': (self.obs_loss_weight, 'loss_obs_harmony_s'),
    'loss_rewards': (self.reward_loss_weight, 'loss_rewards_harmony_s'),
    'loss_policy': (self.policy_loss_weight, 'loss_policy_harmony_s'),
    'loss_value': (self.value_loss_weight, 'loss_value_harmony_s'),
    'loss_ends': (self.ends_loss_weight, 'loss_ends_harmony_s'),
    'latent_recon_loss': (self.latent_recon_loss_weight, 'latent_recon_loss_harmony_s'),
    'perceptual_loss': (self.perceptual_loss_weight, 'perceptual_loss_harmony_s')
}

# Accumulate the weighted losses from kwargs
for k, v in kwargs.items():
    if k in loss_weights:
        weight, harmony_key = loss_weights[k]
        # Look the learnable parameter up in harmony_s_dict rather than
        # globals(), which is fragile and ignores instance state
        harmony_s = harmony_s_dict.get(harmony_key) if harmony_s_dict is not None else None

        if harmony_s is not None:
            # Learnable balancing: scale v down by exp(harmony_s) and add
            # log(exp(harmony_s) + 1) to penalize large harmony_s
            self.loss_total += (v / torch.exp(harmony_s)) + torch.log(torch.exp(harmony_s) + 1)
        else:
            # Fall back to the fixed static weight
            self.loss_total += weight * v
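A quick numeric sanity check of the harmony weighting may help. This is a minimal pure-Python sketch (the helper name `harmony_term` is hypothetical, not from the PR): each raw loss `v` with learnable log-scale `s` contributes `v / exp(s) + log(exp(s) + 1)`, so `s = 0` yields `v + log 2`, and increasing `s` shrinks the `v` contribution while the penalty term grows.

```python
import math

# Hypothetical helper mirroring the harmony weighting in the PR:
# a raw loss v with learnable log-scale s contributes
#   v / exp(s) + log(exp(s) + 1)
def harmony_term(v: float, s: float) -> float:
    return v / math.exp(s) + math.log(math.exp(s) + 1)

# At s = 0 the term is v + log 2; raising s down-weights v, but the
# log(exp(s) + 1) penalty grows roughly linearly in s, so the
# optimizer cannot zero out a loss term for free.
baseline = harmony_term(2.0, 0.0)  # 2.0 + log 2
damped = harmony_term(2.0, 2.0)    # 2.0 / e^2 + log(e^2 + 1)
```

In a real training loop each `harmony_s` would be a `torch` parameter updated by the optimizer, which is why the PR stores them in `harmony_s_dict` rather than as module-level globals.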


# else:
Collaborator commented:

Delete (remove this commented-out line).
