Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scatter reduce lowering with include_self=False #3153

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

apbose
Copy link
Collaborator

@apbose apbose commented Sep 11, 2024

This is for scatter_reduce decomposition where include_self=False

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Sep 11, 2024
@github-actions github-actions bot requested a review from gs-olive September 11, 2024 04:26
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

@apbose apbose force-pushed the scatter_reduce_decomposition_include_self branch from 9a8c124 to 645a725 Compare September 11, 2024 04:31
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

@apbose apbose force-pushed the scatter_reduce_decomposition_include_self branch from 645a725 to 0575e51 Compare September 11, 2024 04:39
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py	2024-09-11 04:39:39.095424+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py	2024-09-11 04:39:59.142801+00:00
@@ -302,24 +302,32 @@
        obj.description = description
        obj.func = func
        return obj

    def reduce_operation_with_scatter_include_self(
-        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor,  min_ele = float('-inf'), max_ele = float('inf'), include_self=True
+        self,
+        operation_lhs,
+        initial_tensor,
+        dim,
+        index_tensor,
+        src_tensor,
+        min_ele=float("-inf"),
+        max_ele=float("inf"),
+        include_self=True,
    ):
        scatter_tensor = None
        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
            scatter_tensor = torch.zeros_like(initial_tensor)
        elif self == ReduceOperation.PROD:
            scatter_tensor = torch.ones_like(initial_tensor)
        elif self == ReduceOperation.AMAX:
            scatter_tensor = initial_tensor
-            if(not(include_self)):
+            if not (include_self):
                scatter_tensor = torch.full_like(initial_tensor, min_ele)
        elif self == ReduceOperation.AMIN:
            scatter_tensor = initial_tensor
-            if(not(include_self)):
+            if not (include_self):
                scatter_tensor = torch.full_like(initial_tensor, max_ele)
        else:
            # This case would not be encountered from torch itself
            print("Invalid Operation for Reduce op!!")

@@ -342,27 +350,35 @@
    include_self: bool = True,
) -> torch.Tensor:
    scatter_loop_tensor = input_tensor
    MAX_ELE = 0
    MIN_ELE = 0
-    if(src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32):
+    if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
        MAX_ELE = 2147483647
        MIN_ELE = -2147483648
    else:
-        MAX_ELE = float('inf')
-        MIN_ELE = float('-inf')
-    if(not(include_self)):
-        if (reduce == "sum" or reduce == "mean"):
-            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor))
-        if (reduce == "prod"):
-            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.ones_like(src_tensor))
-        if (reduce == "amax"):
+        MAX_ELE = float("inf")
+        MIN_ELE = float("-inf")
+    if not (include_self):
+        if reduce == "sum" or reduce == "mean":
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
+            )
+        if reduce == "prod":
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
+            )
+        if reduce == "amax":
            src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
-            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
-        if (reduce == "amin"):
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, src_red_tensor
+            )
+        if reduce == "amin":
            src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
-            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, src_red_tensor
+            )

    device_input_tensor = input_tensor.device
    # required for mean reduce operation
    scatter_count_tensor = torch.zeros_like(input_tensor)
    src_shape = list(src_tensor.shape)
@@ -390,34 +406,55 @@
                dim,
                index_slice,
                torch.ones_like(src_slice),
                MIN_ELE,
                MAX_ELE,
-                include_self
+                include_self,
            )
        elif reduce == "amax":
            reduceOp = ReduceOperation.AMAX
        elif reduce == "amin":
            reduceOp = ReduceOperation.AMIN
        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
-            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice, MIN_ELE, MAX_ELE, include_self
+            scatter_loop_tensor,
+            input_tensor,
+            dim,
+            index_slice,
+            src_slice,
+            MIN_ELE,
+            MAX_ELE,
+            include_self,
        )
    if reduce == "mean":
        scatter_loop_tensor = torch.div(
            scatter_loop_tensor,
-            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))  if include_self else scatter_count_tensor,
+            (
+                torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
+                if include_self
+                else scatter_count_tensor
+            ),
            rounding_mode="trunc",
        )
-    #for include_self cases for amax and amin additional processing is required
-    #except for the max elements in amax, rest are -inf or INT_MIN
-    #except for the min elements in amin, rest are +inf or INT_MAX
-    if reduce == "amax" and not(include_self):
-        #the relevant should be min, rest original
-        return torch.max(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)))
-    if reduce == "amin" and not(include_self):
-        #the relevant should be min, rest original
-        return torch.min(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)))
+    # for include_self cases for amax and amin additional processing is required
+    # except for the max elements in amax, rest are -inf or INT_MIN
+    # except for the min elements in amin, rest are +inf or INT_MAX
+    if reduce == "amax" and not (include_self):
+        # the relevant should be min, rest original
+        return torch.max(
+            scatter_loop_tensor,
+            torch.scatter(
+                input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
+            ),
+        )
+    if reduce == "amin" and not (include_self):
+        # the relevant should be min, rest original
+        return torch.min(
+            scatter_loop_tensor,
+            torch.scatter(
+                input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
+            ),
+        )
    return scatter_loop_tensor


def get_decompositions(
    enable_experimental_decompositions: bool = False,

@apbose apbose force-pushed the scatter_reduce_decomposition_include_self branch from 0575e51 to 47fec01 Compare September 11, 2024 04:50
@apbose apbose force-pushed the scatter_reduce_decomposition_include_self branch from 47fec01 to 8900f67 Compare October 15, 2024 04:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants