Commit 3bb0a44

Merge pull request #25 from bit-bots/feature/all_encoders
Integrate all sensor encoders
2 parents a12c8f2 + 395a5c0

16 files changed: +490 -59 lines

ddlitlab2024/ml/model/encoder/action_history.py renamed to ddlitlab2024/ml/model/encoder/base.py

Lines changed: 8 additions & 9 deletions
@@ -4,23 +4,23 @@
 from ddlitlab2024.ml.model.misc import PositionalEncoding
 
 
-class ActionHistoryEncoder(nn.Module):
+class BaseEncoder(nn.Module):
     """
-    Transformer encoder that encodes the action history of the robot.
+    Transformer encoder that encodes a sequence of input vectors into context tokens.
     """
 
-    def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_seq_len):
+    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, num_heads: int, max_seq_len: int):
         """
         Initializes the module.
 
-        :param num_joints: The number of joints in the robot.
+        :param input_dim: The number of input dimensions.
         :param hidden_dim: The number of hidden dimensions.
         :param num_layers: The number of transformer layers.
         :param num_heads: The number of attention heads.
         :param max_seq_len: The maximum length of the input sequences (used for positional encoding).
         """
         super().__init__()
-        self.embedding = nn.Linear(num_joints, hidden_dim)
+        self.embedding = nn.Linear(input_dim, hidden_dim)
         self.positional_encoding = PositionalEncoding(hidden_dim, max_seq_len)
         self.transformer_encoder = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
@@ -34,14 +34,13 @@ def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_seq_len):
             num_layers=num_layers,
         )
 
-    def forward(self, past_actions: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Encodes the past actions of the robot as context tokens.
+        Encodes the input vectors into context tokens.
 
-        :param past_actions: The past actions of the robot. Shape: (batch_size, seq_len, joint)
+        :param x: The input vectors. Shape: (batch_size, seq_len, input_dim)
         :return: The encoded context tokens. Shape: (batch_size, seq_len, hidden_dim)
         """
-        x = past_actions
         # Embed the input
         x = self.embedding(x)
         # Positional encoding
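
For orientation, here is a minimal, hypothetical smoke test of the renamed BaseEncoder. The dimensions are invented; only the constructor and forward signatures come from the diff above:

    import torch

    from ddlitlab2024.ml.model.encoder.base import BaseEncoder

    # Invented dimensions; hidden_dim must be divisible by num_heads.
    encoder = BaseEncoder(input_dim=20, hidden_dim=64, num_layers=2, num_heads=4, max_seq_len=100)

    x = torch.randn(8, 50, 20)  # (batch_size, seq_len, input_dim)
    tokens = encoder(x)         # (batch_size, seq_len, hidden_dim)
    assert tokens.shape == (8, 50, 64)

Generalizing the constructor from num_joints to input_dim is what lets the new sensor encoders below reuse this class unchanged.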
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+from enum import Enum
+
+import torch
+from torch import nn
+from torchvision.models import resnet18, resnet50, swin_s, swin_t
+
+from ddlitlab2024.ml.model.encoder.base import BaseEncoder
+
+
+class ImageEncoderType(Enum):
+    """
+    Enum class for the image encoder types.
+    """
+
+    RESNET18 = "resnet18"
+    RESNET50 = "resnet50"
+    SWIN_TRANSFORMER_TINY = "swin_transformer_tiny"
+    SWIN_TRANSFORMER_SMALL = "swin_transformer_small"
+
+
+class SequenceEncoderType(Enum):
+    """
+    Enum class for the sequence encoder types.
+    """
+
+    TRANSFORMER = "transformer"
+    NONE = "none"
+
+
+class AbstractImageEncoder(nn.Module):
+    """
+    Abstract class for image encoders.
+    """
+
+    encoder: nn.Module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the image encoder.
+
+        :param x: A sequence of images.
+        :return: A sequence of encoded images.
+        """
+        # Squash the sequence dimension together with the batch dimension
+        images = x.view(-1, *x.shape[2:])
+
+        # Encode the images into tokens
+        tokens = self.encoder(images)
+
+        # Restore the original sequence dimension
+        return tokens.view(x.shape[0], x.shape[1], -1)
+
+
+class ResNetImageEncoder(AbstractImageEncoder):
+    """
+    ResNet image encoder.
+    """
+
+    def __init__(self, resnet_type: ImageEncoderType, hidden_dim: int):
+        super().__init__()
+        match resnet_type:
+            case ImageEncoderType.RESNET18:
+                self.encoder = resnet18(pretrained=True)
+            case ImageEncoderType.RESNET50:
+                self.encoder = resnet50(pretrained=True)
+            case _:
+                raise ValueError(f"Invalid ResNet type: {resnet_type}")
+        # TODO check for softmax layer etc.
+        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, hidden_dim)
+
+
+class SwinTransformerImageEncoder(AbstractImageEncoder):
+    """
+    Swin Transformer image encoder.
+    """
+
+    def __init__(self, swin_type: ImageEncoderType, hidden_dim: int):
+        super().__init__()
+        match swin_type:
+            case ImageEncoderType.SWIN_TRANSFORMER_TINY:
+                self.encoder = swin_t()
+            case ImageEncoderType.SWIN_TRANSFORMER_SMALL:
+                self.encoder = swin_s()
+            case _:
+                raise ValueError(f"Invalid Swin Transformer type: {swin_type}")
+        self.encoder.head = nn.Linear(self.encoder.head.in_features, hidden_dim)
+
+
+class TransformerImageSequenceEncoder(nn.Module):
+    """
+    Transformer image sequence encoder.
+    """
+
+    def __init__(self, image_encoder: AbstractImageEncoder, hidden_dim: int, num_layers: int, max_seq_len: int):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.transformer_encoder = BaseEncoder(
+            input_dim=hidden_dim, hidden_dim=hidden_dim, num_layers=num_layers, num_heads=8, max_seq_len=max_seq_len
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.transformer_encoder(self.image_encoder(x))
+
+
+def image_encoder_factory(encoder_type: ImageEncoderType, hidden_dim: int) -> AbstractImageEncoder:
+    """
+    Factory function for creating image encoders.
+
+    :param encoder_type: The type of the image encoder.
+    :return: The image encoder.
+    """
+    if encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET50]:
+        return ResNetImageEncoder(encoder_type, hidden_dim)
+    if encoder_type in [ImageEncoderType.SWIN_TRANSFORMER_TINY, ImageEncoderType.SWIN_TRANSFORMER_SMALL]:
+        return SwinTransformerImageEncoder(encoder_type, hidden_dim)
+    else:
+        raise ValueError(f"Invalid image encoder type: {encoder_type}")
+
+
+def image_sequence_encoder_factory(
+    encoder_type: SequenceEncoderType,
+    image_encoder_type: ImageEncoderType,
+    hidden_dim: int,
+    num_layers: int,
+    max_seq_len: int,
+):
+    """
+    Factory function for creating image sequence encoders.
+
+    :param encoder_type: The type of the sequence encoder that allows communication between different images.
+        If no sequence encoder is needed, the image encoder is returned.
+    :param image_encoder_type: The type of the image encoder.
+    :return: The image sequence encoder.
+    """
+    image_encoder = image_encoder_factory(image_encoder_type, hidden_dim)
+
+    match encoder_type:
+        case SequenceEncoderType.TRANSFORMER:
+            return TransformerImageSequenceEncoder(image_encoder, hidden_dim, num_layers, max_seq_len)
+        case SequenceEncoderType.NONE:
+            return image_encoder
+        case _:
+            raise ValueError(f"Invalid sequence encoder type: {encoder_type}")

ddlitlab2024/ml/model/encoder/imu.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from enum import Enum
+
+from ddlitlab2024.ml.model.encoder.base import BaseEncoder
+
+
+class IMUEncoder(BaseEncoder):
+    """
+    Transformer encoder that encodes the IMU data of the robot.
+    """
+
+    class OrientationEmbeddingMethod(Enum):
+        """
+        Enum class for the orientation embedding methods.
+        """
+
+        QUATERNION = "quaternion"
+        FIVE_DIM = "five_dim"  # Axis-angle with 2d vector for the angle
+
+    def __init__(
+        self,
+        orientation_embedding_method: OrientationEmbeddingMethod,
+        hidden_dim: int,
+        num_layers: int,
+        num_heads: int,
+        max_seq_len: int,
+    ):
+        """
+        Initializes the module.
+
+        :param orientation_embedding_method: The method used to embed the orientation data.
+        :param hidden_dim: The number of hidden dimensions.
+        :param num_layers: The number of transformer layers.
+        :param num_heads: The number of attention heads.
+        :param max_seq_len: The maximum length of the input sequences (used for positional encoding).
+        """
+
+        # Calculate the number of input features
+        match orientation_embedding_method:
+            case IMUEncoder.OrientationEmbeddingMethod.QUATERNION:
+                input_features = 4
+            case IMUEncoder.OrientationEmbeddingMethod.FIVE_DIM:
+                input_features = 5
+
+        super().__init__(
+            input_dim=input_features,
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            max_seq_len=max_seq_len,
+        )
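
As a usage note, here is a hypothetical sketch of the quaternion path, plus one way the FIVE_DIM embedding (axis-angle with a 2D vector for the angle) could be constructed. The conversion helper is not part of the commit:

    import torch

    from ddlitlab2024.ml.model.encoder.imu import IMUEncoder

    encoder = IMUEncoder(
        orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION,
        hidden_dim=64,
        num_layers=2,
        num_heads=4,
        max_seq_len=100,
    )

    quats = torch.randn(8, 50, 4)
    quats = quats / quats.norm(dim=-1, keepdim=True)  # unit quaternions
    tokens = encoder(quats)  # (batch_size, seq_len, hidden_dim)

    def five_dim_embedding(axis: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
        """Hypothetical helper: concatenate the unit rotation axis (..., 3)
        with (sin, cos) of the angle (..., 1), giving a (..., 5) embedding."""
        return torch.cat([axis, torch.sin(angle), torch.cos(angle)], dim=-1)

Representing the angle as (sin, cos) instead of a raw scalar keeps the embedding continuous across the wrap-around at 2π, which is presumably the motivation for the five-dimensional variant.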
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from ddlitlab2024.ml.model.encoder.base import BaseEncoder
+
+
+class JointEncoder(BaseEncoder):
+    """
+    Joint encoder that encodes the joint states of the robot.
+    """
+
+    def __init__(self, num_joints: int, hidden_dim: int, num_layers: int, num_heads: int, max_seq_len: int):
+        """
+        Initializes the module.
+
+        :param num_joints: The number of joints in the robot.
+        :param hidden_dim: The number of hidden dimensions.
+        :param num_layers: The number of transformer layers.
+        :param num_heads: The number of attention heads.
+        :param max_seq_len: The maximum length of the input sequences (used for positional encoding).
+        """
+        super().__init__(
+            input_dim=num_joints,
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            max_seq_len=max_seq_len,
+        )
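
Since every encoder in this commit emits tokens of shape (batch_size, seq_len, hidden_dim), the per-sensor outputs can share one downstream context. The wiring below is purely illustrative and is not shown in the commit:

    import torch

    # Stand-in outputs with a shared hidden_dim of 64
    joint_tokens = torch.randn(8, 50, 64)  # e.g. from JointEncoder
    imu_tokens = torch.randn(8, 50, 64)    # e.g. from IMUEncoder
    image_tokens = torch.randn(8, 10, 64)  # e.g. from the image sequence encoder

    # Concatenate along the sequence dimension into one context sequence
    context = torch.cat([joint_tokens, imu_tokens, image_tokens], dim=1)
    assert context.shape == (8, 110, 64)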

ddlitlab2024/ml/model/encoder/joint_states.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
