diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py
index d5a8c9692..b8f775eac 100644
--- a/segment_anything/automatic_mask_generator.py
+++ b/segment_anything/automatic_mask_generator.py
@@ -189,6 +189,7 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
                 "point_coords": [mask_data["points"][idx].tolist()],
                 "stability_score": mask_data["stability_score"][idx].item(),
                 "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+                "low_res": mask_data["low_res"][idx],
             }
             curr_anns.append(ann)
 
@@ -276,7 +277,7 @@ def _process_batch(
         transformed_points = self.predictor.transform.apply_coords(points, im_size)
         in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
         in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
-        masks, iou_preds, _ = self.predictor.predict_torch(
+        masks, iou_preds, low_res = self.predictor.predict_torch(
             in_points[:, None, :],
             in_labels[:, None],
             multimask_output=True,
@@ -288,6 +289,7 @@ def _process_batch(
             masks=masks.flatten(0, 1),
             iou_preds=iou_preds.flatten(0, 1),
             points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+            low_res=low_res.flatten(0, 1),
         )
         del masks