Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export Dense, DenseLevel
export Element, ElementLevel
export AtomicElement, AtomicElementLevel
export Separate, SeparateLevel
export Shard, ShardLevel
export Mutex, MutexLevel
export Pattern, PatternLevel
export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
Expand All @@ -59,6 +60,7 @@ export fill_value, AsArray, expanddims, tensor_tree

export parallelAnalysis, ParallelAnalysisResults
export parallel, extent, auto
export Serial, SerialMemory, CPU, CPULocalArray, CPULocalMemory
export static_schedule, greedy_schedule, julia_schedule
export serial, cpu

Expand Down Expand Up @@ -143,6 +145,7 @@ include("tensors/levels/dense_rle_levels.jl")
include("tensors/levels/element_levels.jl")
include("tensors/levels/atomic_element_levels.jl")
include("tensors/levels/separate_levels.jl")
include("tensors/levels/shard_levels.jl")
include("tensors/levels/mutex_levels.jl")
include("tensors/levels/pattern_levels.jl")
include("tensors/masks.jl")
Expand Down
62 changes: 50 additions & 12 deletions src/architecture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ abstract type AbstractVirtualTask end
Return the number of tasks on the device dev.
"""
function get_num_tasks end

"""
get_task_num(task::AbstractTask)

Return the task number of `task`.
"""
function get_task_num end

"""
get_device(task::AbstractTask)

Expand All @@ -61,6 +63,25 @@ Return the task which spawned `task`.
"""
function get_parent_task end

get_num_tasks(ctx::AbstractCompiler) = get_num_tasks(get_task(ctx))
get_num_tasks(task::AbstractTask) = get_num_tasks(get_device(task))
get_task_num(ctx::AbstractCompiler) = get_task_num(get_task(ctx))
get_device(ctx::AbstractCompiler) = get_device(get_task(ctx))
get_parent_task(ctx::AbstractCompiler) = get_parent_task(get_task(ctx))

function is_on_device(ctx::AbstractCompiler, dev)
res = false
task = get_task(ctx)
while task != nothing
if get_device(task) == dev
res = true
break
end
task = get_parent_task(task)
end
return res
end

"""
aquire_lock!(dev::AbstractDevice, val)

Expand Down Expand Up @@ -92,21 +113,36 @@ function make_lock end
"""
Serial()

A device that represents a serial CPU execution.
A Task that represents a serial CPU execution.
"""
struct Serial <: AbstractTask end
struct Serial <: AbstractDevice end
serial() = Serial()
get_device(::Serial) = CPU(1)
get_parent_task(::Serial) = nothing
get_task_num(::Serial) = 1
get_num_tasks(::Serial) = 1
struct VirtualSerial <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial()
lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Finch.Serial())
virtual_call_def(ctx, alg, ::typeof(serial), Any) = VirtualSerial()
FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device)
get_device(::VirtualSerial) = VirtualCPU(literal(1))
get_parent_task(::VirtualSerial) = nothing
get_task_num(::VirtualSerial) = literal(1)
get_num_tasks(::VirtualSerial) = literal(1)
Base.:(==)(::Serial, ::Serial) = true
Base.:(==)(::VirtualSerial, ::VirtualSerial) = true

"""
SerialTask()

A Task that represents a serial CPU execution.
"""
struct SerialTask <: AbstractDevice end
get_device(::SerialTask) = Serial()
get_parent_task(::SerialTask) = nothing
get_task_num(::SerialTask) = 1
struct VirtualSerialTask <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{SerialTask}) = VirtualSerialTask()
lower(ctx::AbstractCompiler, task::VirtualSerialTask, ::DefaultStyle) = :(SerialTask())
FinchNotation.finch_leaf(device::VirtualSerialTask) = virtual(device)
get_device(::VirtualSerialTask) = VirtualSerial()
get_parent_task(::VirtualSerialTask) = nothing
get_task_num(::VirtualSerialTask) = literal(1)

struct SerialMemory end
struct VirtualSerialMemory end
Expand Down Expand Up @@ -159,7 +195,9 @@ end
function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle)
:(Finch.CPU($(ctx(device.n))))
end
get_num_tasks(device::VirtualCPU) = device.n
get_num_tasks(::VirtualCPU) = device.n
Base.:(==)(::CPU, ::CPU) = true
Base.:(==)(::VirtualCPU, ::VirtualCPU) = true #This is not strictly true. A better approach would name devices, and give them parents so that we can be sure to parallelize through the processor hierarchy.

FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device)

Expand Down Expand Up @@ -220,11 +258,11 @@ Base.eltype(::Type{CPULocalArray{A}}) where {A} = eltype(A)
Base.ndims(::Type{CPULocalArray{A}}) where {A} = ndims(A)

transfer(device::Union{CPUThread,CPUSharedMemory}, arr::AbstractArray) = arr
function transfer(device::CPULocalMemory, arr::AbstractArray)
CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
function transfer(mem::CPULocalMemory, arr::AbstractArray)
CPULocalArray{typeof(arr)}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
end
function transfer(task::CPUThread, arr::CPULocalArray)
if get_device(task) === arr.device
if get_device(task) == arr.device
temp = arr.data[task.tid]
return temp
else
Expand Down
2 changes: 1 addition & 1 deletion src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ variable names in the generated code of the executing environment.
namespace::Namespace = Namespace()
preamble::Vector{Any} = []
epilogue::Vector{Any} = []
task = VirtualSerial()
task = VirtualSerialTask()
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/element_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end

postype(::Type{<:ElementLevel{Vf,Tv,Tp}}) where {Vf,Tv,Tp} = Tp

function transfer(lvl::ElementLevel{Vf,Tv,Tp}, device, style) where {Vf,Tv,Tp}
function transfer(device, lvl::ElementLevel{Vf,Tv,Tp}) where {Vf,Tv,Tp}
return ElementLevel{Vf,Tv,Tp}(transfer(device, lvl.val))
end

Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/separate_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ countstored_level(lvl::SeparateLevel, pos) = pos

mutable struct VirtualSeparateLevel <: AbstractVirtualLevel
tag
lvl # stand in for the sublevel for virutal resize, etc.
lvl # stand in for the sublevel for virtual resize, etc.
val
Tv
Lvl
Expand Down
Loading