Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export Dense, DenseLevel
export Element, ElementLevel
export AtomicElement, AtomicElementLevel
export Separate, SeparateLevel
export Shard, ShardLevel
export Mutex, MutexLevel
export Pattern, PatternLevel
export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
Expand All @@ -59,6 +60,7 @@ export fill_value, AsArray, expanddims, tensor_tree

export parallelAnalysis, ParallelAnalysisResults
export parallel, extent, auto
export Serial, SerialMemory, CPU, CPULocalArray, CPULocalMemory
export static_schedule, greedy_schedule, julia_schedule
export serial, cpu

Expand Down Expand Up @@ -143,6 +145,7 @@ include("tensors/levels/dense_rle_levels.jl")
include("tensors/levels/element_levels.jl")
include("tensors/levels/atomic_element_levels.jl")
include("tensors/levels/separate_levels.jl")
include("tensors/levels/shard_levels.jl")
include("tensors/levels/mutex_levels.jl")
include("tensors/levels/pattern_levels.jl")
include("tensors/masks.jl")
Expand Down
62 changes: 50 additions & 12 deletions src/architecture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ abstract type AbstractVirtualTask end
Return the number of tasks on the device dev.
"""
function get_num_tasks end

"""
get_task_num(task::AbstractTask)

Return the task number of `task`.
"""
function get_task_num end

"""
get_device(task::AbstractTask)

Expand All @@ -61,6 +63,25 @@ Return the task which spawned `task`.
"""
function get_parent_task end

get_num_tasks(ctx::AbstractCompiler) = get_num_tasks(get_task(ctx))
get_num_tasks(task::AbstractTask) = get_num_tasks(get_device(task))
get_task_num(ctx::AbstractCompiler) = get_task_num(get_task(ctx))
get_device(ctx::AbstractCompiler) = get_device(get_task(ctx))
get_parent_task(ctx::AbstractCompiler) = get_parent_task(get_task(ctx))

function is_on_device(ctx::AbstractCompiler, dev)
res = false
task = get_task(ctx)
while task != nothing
if get_device(task) == dev
res = true
break
end
task = get_parent_task(task)
end
return res
end

"""
aquire_lock!(dev::AbstractDevice, val)

Expand Down Expand Up @@ -92,21 +113,36 @@ function make_lock end
"""
Serial()

A device that represents a serial CPU execution.
A Task that represents a serial CPU execution.
"""
struct Serial <: AbstractTask end
struct Serial <: AbstractDevice end
serial() = Serial()
get_device(::Serial) = CPU(1)
get_parent_task(::Serial) = nothing
get_task_num(::Serial) = 1
get_num_tasks(::Serial) = 1
struct VirtualSerial <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial()
lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Finch.Serial())
virtual_call_def(ctx, alg, ::typeof(serial), Any) = VirtualSerial()
FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device)
get_device(::VirtualSerial) = VirtualCPU(literal(1))
get_parent_task(::VirtualSerial) = nothing
get_task_num(::VirtualSerial) = literal(1)
get_num_tasks(::VirtualSerial) = literal(1)
Base.:(==)(::Serial, ::Serial) = true
Base.:(==)(::VirtualSerial, ::VirtualSerial) = true

"""
SerialTask()

A Task that represents a serial CPU execution.
"""
struct SerialTask <: AbstractDevice end
get_device(::SerialTask) = Serial()
get_parent_task(::SerialTask) = nothing
get_task_num(::SerialTask) = 1
struct VirtualSerialTask <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{SerialTask}) = VirtualSerialTask()
lower(ctx::AbstractCompiler, task::VirtualSerialTask, ::DefaultStyle) = :(SerialTask())
FinchNotation.finch_leaf(device::VirtualSerialTask) = virtual(device)
get_device(::VirtualSerialTask) = VirtualSerial()
get_parent_task(::VirtualSerialTask) = nothing
get_task_num(::VirtualSerialTask) = literal(1)

struct SerialMemory end
struct VirtualSerialMemory end
Expand Down Expand Up @@ -159,7 +195,9 @@ end
function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle)
:(Finch.CPU($(ctx(device.n))))
end
get_num_tasks(device::VirtualCPU) = device.n
get_num_tasks(::VirtualCPU) = device.n
Base.:(==)(::CPU, ::CPU) = true
Base.:(==)(::VirtualCPU, ::VirtualCPU) = true #This is not strictly true. A better approach would name devices, and give them parents so that we can be sure to parallelize through the processor hierarchy.

FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device)

Expand Down Expand Up @@ -220,11 +258,11 @@ Base.eltype(::Type{CPULocalArray{A}}) where {A} = eltype(A)
Base.ndims(::Type{CPULocalArray{A}}) where {A} = ndims(A)

transfer(device::Union{CPUThread,CPUSharedMemory}, arr::AbstractArray) = arr
function transfer(device::CPULocalMemory, arr::AbstractArray)
CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
function transfer(mem::CPULocalMemory, arr::AbstractArray)
CPULocalArray{typeof(arr)}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
end
function transfer(task::CPUThread, arr::CPULocalArray)
if get_device(task) === arr.device
if get_device(task) == arr.device
temp = arr.data[task.tid]
return temp
else
Expand Down
2 changes: 1 addition & 1 deletion src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ variable names in the generated code of the executing environment.
namespace::Namespace = Namespace()
preamble::Vector{Any} = []
epilogue::Vector{Any} = []
task = VirtualSerial()
task = VirtualSerialTask()
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/element_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end

postype(::Type{<:ElementLevel{Vf,Tv,Tp}}) where {Vf,Tv,Tp} = Tp

function transfer(lvl::ElementLevel{Vf,Tv,Tp}, device, style) where {Vf,Tv,Tp}
function transfer(device, lvl::ElementLevel{Vf,Tv,Tp}) where {Vf,Tv,Tp}
return ElementLevel{Vf,Tv,Tp}(transfer(device, lvl.val))
end

Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/separate_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ countstored_level(lvl::SeparateLevel, pos) = pos

mutable struct VirtualSeparateLevel <: AbstractVirtualLevel
tag
lvl # stand in for the sublevel for virutal resize, etc.
lvl # stand in for the sublevel for virtual resize, etc.
val
Tv
Lvl
Expand Down
Loading
Loading