
Add rewrites to lift/flatten Subtensors applied to IncSubtensors #1500

Open
@brandonwillard


The following illustrates the missing rewrite/optimization:

```python
import aesara
import aesara.tensor as at


# The proposed rewrite should also work when `A` is an array
A = 1
B = at.set_subtensor(at.zeros((11, 2))[:1], A)[:5]

aesara.dprint(B)
# Subtensor{:int64:} [id A]
#  |IncSubtensor{Set;:int64:} [id B]
#  | |Alloc [id C]
#  | | |TensorConstant{0.0} [id D]
#  | | |TensorConstant{11} [id E]
#  | | |TensorConstant{2} [id F]
#  | |TensorConstant{1} [id G]
#  | |ScalarConstant{1} [id H]
#  |ScalarConstant{5} [id I]

f_B = aesara.function([], B, mode="FAST_RUN")

# Aside from making the `IncSubtensor` in-place, no rewrites have been applied,
# so we're allocating an unnecessarily large array (i.e. with shape (11, 2)
# instead of (5, 2)):
aesara.dprint(f_B)
# Subtensor{:int64:} [id A] 2
#  |IncSubtensor{InplaceSet;:int64:} [id B] 1
#  | |Alloc [id C] 0
#  | | |TensorConstant{0.0} [id D]
#  | | |TensorConstant{11} [id E]
#  | | |TensorConstant{2} [id F]
#  | |TensorConstant{1} [id G]
#  | |ScalarConstant{1} [id H]
#  |ScalarConstant{5} [id I]

f_B()
# array([[1., 1.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.]])
```
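To make the proposed rewrite concrete, here is a minimal NumPy sketch of the equivalence it would exploit. It is not Aesara code: the helpers `set_subtensor_np` and `lifted_subtensor_of_set` are hypothetical, and the sketch assumes both indices are unit-step slices on the leading axis. Under those assumptions, `set_subtensor(x[inner], y)[outer]` can be computed by slicing `x` with `outer` first and writing only the rows of `y` that land inside that window. For the example above, that amounts to computing `set_subtensor(zeros((5, 2))[:1], A)` instead of materializing an `(11, 2)` array.

```python
import numpy as np


def set_subtensor_np(x, region, y):
    """NumPy analogue of `at.set_subtensor(x[region], y)`; returns a modified copy of `x`."""
    out = x.copy()
    out[region] = y
    return out


def lifted_subtensor_of_set(x, inner, y, outer):
    """Hypothetical sketch: compute `set_subtensor_np(x, inner, y)[outer]` without
    building the full-size intermediate.

    Assumes `inner` and `outer` are unit-step slices on the leading axis.
    """
    n = x.shape[0]
    # Normalize None/negative bounds to concrete, clamped indices
    i_start, i_stop, i_step = inner.indices(n)
    o_start, o_stop, o_step = outer.indices(n)
    if i_step != 1 or o_step != 1:
        raise NotImplementedError("this sketch only handles unit-step slices")

    # Lift the outer Subtensor onto the base array (a view, no copy yet)
    base = x[o_start:o_stop]

    # Rows that are written and also survive the outer slice, in `x` coordinates
    lo = max(i_start, o_start)
    hi = min(i_stop, o_stop)
    if lo >= hi:
        # The write lies entirely outside the kept window, so it can be dropped
        return base.copy()

    # Broadcast `y` to the written region so the surviving rows can be selected
    y_full = np.broadcast_to(y, x[i_start:i_stop].shape)
    return set_subtensor_np(
        base,
        slice(lo - o_start, hi - o_start),  # destination rows within `base`
        y_full[lo - i_start : hi - i_start],  # matching rows of `y`
    )
```

For the example in this issue, `lifted_subtensor_of_set(np.zeros((11, 2)), slice(None, 1), 1, slice(None, 5))` yields the same values as `f_B()` while only copying five rows. An actual graph rewrite would also have to handle other index types, non-unit steps, and symbolic (non-constant) bounds, all of which this sketch deliberately ignores.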

Aside from being a generally useful optimization, it would also simplify, or even obviate, all the logic here in `save_mem_new_scan`, and perhaps in other places as well.
