@AdrianGushin
Collaborator

This pull request adds some fixes for the ShardLevel structure and adds associated test cases. It also updates the CPU struct to facilitate multiple distinct cores.

@AdrianGushin AdrianGushin changed the base branch from main to wma/shard_levels September 24, 2025 15:31
Member

@willow-ahrens willow-ahrens left a comment

Almost there! This is looking great, only small changes requested.

Project.toml Outdated
DataStructures = "0.18"
Distributions = "0.25"
HDF5 = "0.17"
InteractiveUtils = "1.11.0"

InteractiveUtils shouldn't be a Finch dep; we need to remove it before merging.

Project.toml Outdated
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"

TensorMarket should be a test dep, but I don't think it's an extra, right? Do we need to move it to the test Project.toml?

@@ -1,8 +1,10 @@
using InteractiveUtils

Let's remove this; it was just for debugging.

A datatype representing a device on which tasks can be executed.
"""

Does this break the documentation for the abstract device?

end,
)
VirtualCPU(value(n, Int))
VirtualCPU(value(n, Int), literal(id))

It's also okay to just use id without wrapping it as a literal, as long as you wrap it when needed. Whatever's convenient.

FinchNotation.finch_leaf(mem::VirtualCPULocalMemory) = virtual(mem)
function virtualize(ctx, ex, ::Type{CPULocalMemory})
VirtualCPULocalMemory(virtualize(ctx, :($ex.device), CPU))
function virtualize(ctx, ex, ::Type{CPULocalMemory{id}}) where {id}

Instead of keying CPULocalMemory on id, let's put the whole CPU type in the type parameter, so that further changes to the CPU parameterization don't need to affect the local memory.

Then we can recursively virtualize the CPU.
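
A minimal sketch of what keying on the device type could look like. The struct definitions below are simplified, hypothetical stand-ins (the field name `n` and the `Device` parameter name are assumptions, not Finch's actual definitions); only the function names come from the diff under review.

```julia
# Hypothetical, simplified stand-ins for illustration only; Finch's real
# definitions carry more fields and methods.
struct CPU{id}
    n::Int  # assumed field: number of threads
end

# The memory types take the whole device type as their parameter, so a
# later change to CPU's own parameters does not touch these definitions.
struct CPULocalMemory{Device}
    device::Device
end

struct CPUSharedMemory{Device}
    device::Device
end

# No `where {id}` is needed: the parameter is inferred from the argument.
local_memory(device::CPU) = CPULocalMemory(device)
shared_memory(device::CPU) = CPUSharedMemory(device)
global_memory(device::CPU) = CPUSharedMemory(device)

mem = local_memory(CPU{1}(4))
```

Under these assumptions, the `virtualize` method could dispatch on `::Type{CPULocalMemory{Device}}` and pass `Device` to the recursive `virtualize` call on `$ex.device`, in place of the hard-coded `CPU`.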

global_memory(device::CPU) = CPUSharedMemory(device)
local_memory(device::CPU{id}) where {id} = CPULocalMemory{id}(device)
shared_memory(device::CPU{id}) where {id} = CPUSharedMemory{id}(device)
global_memory(device::CPU{id}) where {id} = CPUSharedMemory{id}(device)

Same thing here: I think it makes more sense to key on the CPU than on the id. Feel free to disagree.

function transfer(task::MemoryChannel, arr::MultiChannelBuffer)
if task.device == arr.device
temp = arr.data[task.t]
@assert isa(temp, Vector)

This is good for debugging, but I don't think it will always hold; we might have buffer types other than Vector.
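
One possible way to loosen the check, as a sketch only. `check_buffer` is a hypothetical helper, not part of Finch, and it assumes that every buffer type of interest subtypes `AbstractVector`, which this review does not confirm.

```julia
# Hypothetical helper, not part of Finch. Assumes every buffer type of
# interest subtypes AbstractVector, which the review does not confirm.
function check_buffer(buf)
    buf isa AbstractVector ||
        throw(ArgumentError("expected an AbstractVector, got $(typeof(buf))"))
    return buf
end

check_buffer([1, 2, 3])              # a plain Vector passes
check_buffer(view([1, 2, 3], 1:2))   # so does a SubArray, which is not a Vector
```

Dropping the assertion entirely would be the other option if buffers need not be vector-like at all.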


@testset "Finch" begin
include("modules/checkoutput_testsetup.jl")
include("suites/constructors_tests.jl")

Wow. It's crazy that this wasn't already included.

end

@test C[4,4] == 12
end

Good test! Let's also add a test, perhaps in representation.jl, which generates some reference output for a shard level kernel.

@willow-ahrens
Member

We can merge this to main once we have some more tests.

Member

@willow-ahrens willow-ahrens left a comment

Hi Adrian! This all looks good. The only remaining requirement is a test which calls check_output to compare the generated ShardLevel code against a reference output. See elsewhere in the code where this function is used for examples. You'll need to follow the instructions in the contributing guide to generate the new 64-bit reference, and run the "fixbot" action to generate the new 32-bit reference.

@willow-ahrens
Member

I've added you to the repo as a collaborator; you can re-open the PR using finch-tensor as the remote, and it will automatically run tests.
