Diffusion approach #6


Open · 1 of 4 tasks
moritzschaefer opened this issue Aug 1, 2024 · 2 comments · May be fixed by #7

moritzschaefer commented Aug 1, 2024

  • Generation works

Next steps:

  • train on (217, 20) (see the log below; the data preprocessing needs to be adjusted to return bigger chunks: 32 -> 224 and 256 -> 1792)
  • condition on a meaningful embedding (for now: fourier)
  • make an example submission (just predict the first 208 datapoints and repeat-pad the rest; see the sketch after this list)
  • if it works well score-wise, predict the second half independently and interpolate
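
A minimal sketch of the repeat-pad submission idea (pure numpy; preds stands in for one trial's (217, 20) model output, with the first 208 datapoints trusted as above):

import numpy as np

preds = np.random.randn(217, 20)  # stand-in for one trial's model output
trusted = preds[:208]             # keep only the first 208 predicted datapoints
# repeat the last trusted row to fill the remaining 217 - 208 = 9 timesteps
submission = np.pad(trusted, ((0, preds.shape[0] - trusted.shape[0]), (0, 0)), mode='edge')
assert submission.shape == preds.shape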

Completed 46/72. Loaded data: (1730, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 47/72. Loaded data: (1730, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 48/72. Loaded data: (1732, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 49/72. Loaded data: (1731, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 50/72. Loaded data: (1735, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 51/72. Loaded data: (1730, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 52/72. Loaded data: (1729, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 53/72. Loaded data: (1730, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 54/72. Loaded data: (1730, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 55/72. Loaded data: (1729, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 56/72. Loaded data: (1734, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 57/72. Loaded data: (1733, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 58/72. Loaded data: (1729, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 59/72. Loaded data: (1731, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 60/72. Loaded data: (1731, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 61/72. Loaded data: (1732, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 62/72. Loaded data: (1736, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 63/72. Loaded data: (1731, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 64/72. Loaded data: (1732, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 65/72. Loaded data: (1733, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 66/72. Loaded data: (1734, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 67/72. Loaded data: (1729, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 68/72. Loaded data: (1727, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (216, 20)
Completed 69/72. Loaded data: (1736, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 70/72. Loaded data: (1729, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 71/72. Loaded data: (1735, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)
Completed 72/72. Loaded data: (1735, 8) - padded to: (1792, 8) - predictions (224, 20) - downsampled to: (217, 20)

moritzschaefer linked a pull request Aug 1, 2024 that will close this issue

moritzschaefer commented Aug 1, 2024

Submission code (same as in the notebook); it is not yet ready for the diffusion model (the evaluate function still needs to be adapted):

import numpy as np

# model, all_paths and LEFT_TO_RIGHT_HAND are defined earlier in the notebook
pred_list = []

# loop over each trial
for i, p in enumerate(all_paths):
    # get EMG data 
    sample = np.load(p)
    myo = sample['data_myo']
    myo = myo[:, LEFT_TO_RIGHT_HAND]

    # predictions will have to be downsampled
    gt_len = myo[::8].shape[0]

    # zero-pad the trial length up to the next multiple of 256
    target_length = (myo.shape[0] + 255) // 256 * 256
    padded_myo = np.pad(myo, ((0, target_length - myo.shape[0]), (0, 0)), mode='constant', constant_values=0)
    # padded_myo = preprocess_batch(torch.from_numpy(padded_myo).unsqueeze(0)).to(device)  # needed for VRNN

    # run the model; this might become a sliding-window prediction later
    # preds = torch.stack(model.forward(padded_myo)[-1])
    preds = model.inference(padded_myo)
    preds_downsampled = preds[:gt_len]  # truncate to the 8x-downsampled ground-truth length
    print(f"Completed {i+1}/{len(all_paths)}. Loaded data: {myo.shape} - padded to: {padded_myo.shape} - predictions {preds.shape} - downsampled to: {preds_downsampled.shape}")
    # pred_list.append(preds_downsampled.permute(1, 0, 2).detach().cpu().numpy())
    pred_list.append(preds_downsampled)

moritzschaefer commented

Messages I sent to @shemlem in Discord:

It is conditioned on windowed fourier features 🙂.

If you have some time tomorrow, it would be cool if you could run the train_diffusion script on the server with a couple of changed hyperparameters (especially the number of fourier features we extract, and maybe a smaller window_size, e.g. the original 256) to see whether evaluate spits out smaller losses.
In the very first run, I already got an MSE of 0.1 on 16 samples of the test set. This looks really good in my opinion (and the trajectories also look much better visually).
If you manage to optimize some of the hyperparameters, we could win this..!!
Note that you would need to install these two libraries with conda:

diffusers=0.29.2
accelerate=0.33.0
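
For context, a minimal sketch of what the windowed fourier-feature conditioning could look like. The actual implementation lives in train_diffusion; the function name and the window_size/n_features defaults below are assumptions:

import numpy as np

def fourier_features(emg, window_size=256, n_features=16):
    # split the (T, 8) EMG signal into non-overlapping windows and keep the
    # magnitudes of the first n_features rFFT bins per window and channel
    n_windows = emg.shape[0] // window_size
    windows = emg[:n_windows * window_size].reshape(n_windows, window_size, -1)
    spectra = np.abs(np.fft.rfft(windows, axis=1))  # (n_windows, window_size // 2 + 1, 8)
    return spectra[:, :n_features, :]               # (n_windows, n_features, 8)
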
moritzs — today at 22:20
Three more things:
Decreasing the window_size (currently 1664) should make things much easier for the model, since it then only needs to predict one movement. Looking at the kinematics labels of a full recorded track, it seems like participants made hand gestures at fixed time intervals (e.g. every 10 seconds). Could you check whether this is indeed the case? If so, we can train a diffusion model only on the small window size and run inference in a windowed multi-step fashion (see the sketch after this list).
Another option is to train the diffusion model on both small 256-sized windows and large 1664-sized windows. This is possible since the conditioning embedding (fourier features) can have flexible lengths (through cross attention). I don't know yet how much the model would benefit from this, though.
As a final step, once we have converged on a model architecture and hyperparameters, we should train the model on both train and test data to increase the dataset size!
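
A rough sketch of the windowed multi-step inference from the first point, assuming model.inference maps a (window_size, 8) EMG chunk to (window_size // 8, 20) predictions as in the submission code above (the function and its defaults are assumptions):

import numpy as np

def windowed_inference(model, myo, window_size=256, downsample=8):
    # predict each window independently and concatenate the results
    gt_len = myo[::downsample].shape[0]
    preds = []
    for start in range(0, myo.shape[0], window_size):
        chunk = myo[start:start + window_size]
        if chunk.shape[0] < window_size:  # repeat-pad the final partial window
            chunk = np.pad(chunk, ((0, window_size - chunk.shape[0]), (0, 0)), mode='edge')
        preds.append(model.inference(chunk))
    return np.concatenate(preds, axis=0)[:gt_len]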
