Skip to content

Commit

Permalink
Mostly more type-stability improvements, some algorithmic too
Browse files Browse the repository at this point in the history
  • Loading branch information
halleysfifthinc committed Aug 8, 2024
1 parent c08e9cf commit 43076ea
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 84 deletions.
15 changes: 13 additions & 2 deletions src/groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,22 @@ Base.haskey(g::Group, key) = haskey(g.params, key)

# TODO: Add get! method?
"""
    get(g::Group, key, default)

Return `g[key]` when `haskey(g, key)` is true; otherwise return `default`.
"""
function Base.get(g::Group, key, default)
    # `key` is passed unchanged to both `haskey` and the indexing call.
    return haskey(g, key) ? g[key] : default
end
"""
    get(g::Group, key::Tuple{Type,Symbol}, default)

Look up `g[key...]` for a `(Type, Symbol)` key. Only the trailing `Symbol` of `key` is
used for the `haskey` presence check; when that check fails, `default` is returned.
"""
function Base.get(g::Group, key::Tuple{Type,Symbol}, default::T) where T
    if haskey(g, last(key))
        # Splat the tuple so the group is indexed as `g[Type, Symbol]`.
        return g[key...]
    else
        return default
    end
end

"""
    show(io::IO, g::Group)

Print a compact, single-line representation of `g`: the text `Group(:<name>) ` followed
by the keys of `g.params`, or by `[]` when `g.params` has no keys.
"""
function Base.show(io::IO, g::Group)
    print(io, "Group(:$(g.name)) ")
    if isempty(keys(g.params))
        print(io, "[]")
    else
        print(io, keys(g.params))
    end

    return nothing
end

function Base.show(io::IO, ::MIME"text/plain", g::Group)
print(io, "Group(:$(g.name))")
Expand Down
53 changes: 31 additions & 22 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ function Base.hash(p::Parameter, h::UInt)
end

# Return the absolute value of the `gid` property of `p`.
# NOTE(review): the sign of the stored value presumably carries separate information
# (e.g. a flag) -- confirm against the `Parameter` definition.
function gid(p::Parameter)
    return abs(p.gid)
end
# `name` accessor: one method per payload kind, each reading the `name` field directly
# through `getfield` (i.e. without going through `getproperty`).
# NOTE(review): the separate per-payload methods presumably exist to help type
# inference -- confirm before collapsing them into a single method.
function name(p::Parameter{ArrayParameter{T,N}}) where {T,N}
    return getfield(p, :name)
end

function name(p::Parameter{StringParameter})
    return getfield(p, :name)
end

function name(p::Parameter{ScalarParameter{T}}) where T
    return getfield(p, :name)
end

# `_position` accessor: one method per payload kind, each reading the `pos` field
# directly through `getfield` (i.e. without going through `getproperty`).
# NOTE(review): `pos` looks like the stream position at which the parameter was read
# -- confirm against the `Parameter` constructor call sites.
function _position(p::Parameter{ArrayParameter{T,N}}) where {T,N}
    return getfield(p, :pos)
end

function _position(p::Parameter{StringParameter})
    return getfield(p, :pos)
end

function _position(p::Parameter{ScalarParameter{T}}) where T
    return getfield(p, :pos)
end
Expand Down Expand Up @@ -157,7 +161,7 @@ function Base.show(io::IO, ::MIME"text/plain", p::Parameter{P}) where P
print(io, "\n ", p.payload.data)
end

function _elsize(::Parameter{P}) where P <: Union{StringParameter,ScalarParameter{String}}
function _elsize(::Parameter{P}) where P <: Union{StringParameter,ScalarParameter{String},ArrayParameter{String}}
return -1
end

Expand Down Expand Up @@ -191,7 +195,7 @@ function _size(p::Parameter{ArrayParameter{T,N}}) where {T,N}
return size(p.payload.data)
end

function readparam(io::IOStream, ::Type{END}) where {END<:AbstractEndian}
function readparam(io::IO, ::Type{END}) where {END<:AbstractEndian}
pos = position(io)
nl = read(io, Int8)
@assert nl != 0
Expand All @@ -200,10 +204,10 @@ function readparam(io::IOStream, ::Type{END}) where {END<:AbstractEndian}
@assert gid != 0
_name = read(io, abs(nl))
@assert any(!iscntrlChar, _name)
name = Symbol(replace(strip(transcode(String, copy(_name))), r"[^a-zA-Z0-9_]" => '_'))
name = Symbol(replace(strip(transcode(String, view(_name, :))), r"[^a-zA-Z0-9_]" => '_'))

@debug "Parameter $name at $pos has unofficially supported characters.
Unexpected results may occur" maxlog=occursin(r"[^a-zA-Z0-9_ ]", transcode(String, copy(_name)))
# @debug "Parameter $name at $pos has unofficially supported characters.
# Unexpected results may occur" maxlog=occursin(r"[^a-zA-Z0-9_ ]", transcode(String, copy(_name)))

np = read(io, END(Int16))

Expand All @@ -221,22 +225,23 @@ function readparam(io::IOStream, ::Type{END}) where {END<:AbstractEndian}
end

nd = read(io, UInt8)
local data::AbstractArray
if nd > 0
dims = NTuple{convert(Int, nd),Int}(read!(io, Array{UInt8}(undef, nd)))
dims = (Int.(read!(io, Array{UInt8}(undef, nd)))...,)
data = _readarrayparameter(io, END(T), dims)
else
data = _readscalarparameter(io, END(T))
end

dl = read(io, UInt8)
dl = read(io, UInt8)::UInt8
desc = read(io, dl)

if any(iscntrlChar, desc)
desc = ""
desc = UInt8[]
end

pointer = pos + np + abs(nl) + 2
@debug "wrong pointer in $name" position(io) pointer maxlog=(position(io) != pointer)
# @debug "wrong pointer in $name" position(io) pointer maxlog=(position(io) != pointer)

if nd > 0
if elsize == -1
Expand All @@ -247,28 +252,28 @@ function readparam(io::IOStream, ::Type{END}) where {END<:AbstractEndian}
end
elseif isone(prod(dims))
# In the event of an 'array' parameter with only one element
payload = ScalarParameter(data[1])
payload = ScalarParameter(only(data))
else
payload = ArrayParameter(elsize, nd, dims, data)
end
else
payload = ScalarParameter(data)
payload = ScalarParameter(only(data))
end

return Parameter(pos, gid, locked, np, _name, name, desc, payload)
end

function _readscalarparameter(io::IO, ::Type{END}) where {END<:AbstractEndian}
return read(io, END)
function _readscalarparameter(io::IO, END::Type{<:AbstractEndian{T}}) where {T}
return [read(io, END)]
end

function _readscalarparameter(io::IO, ::Type{<:AbstractEndian{String}})::String
return rstrip(x -> iscntrl(x) || isspace(x), transcode(String, read(io, UInt8)))
return [rstrip(x -> iscntrl(x) || isspace(x), transcode(String, read(io, UInt8)))]
end

function _readarrayparameter(io::IO, ::Type{END}, dims) where {END<:AbstractEndian}
T = eltype(END) <: VaxFloat ? Float32 : eltype(END)
a = Array{T}(undef, dims)
function _readarrayparameter(io::IO, END::Type{<:AbstractEndian{T}}, dims) where {T}
U = eltype(END) <: VaxFloat ? Float32 : eltype(END)
a = Array{U}(undef, dims)
return read!(io, a, END)
end

Expand All @@ -286,14 +291,18 @@ function rstrip_cntrl_null_space(s)
end

function _readarrayparameter(io::IO, ::Type{<:AbstractEndian{String}}, dims)::Array{String}
tdata = Array{UInt8}(undef, dims)
read!(io, tdata)
# tdata::AbstractArray{UInt8} = Array{UInt8}(undef, dims)
# read!(io, tdata)
_tdata = Vector{UInt8}(undef, prod(dims))
read!(io, _tdata)
tdata = reshape(_tdata, dims)

local data::AbstractArray{String}
if length(dims) > 1
_, rdims... = dims
data = Array{String}(undef, rdims)
for ijk in CartesianIndices(data)
data[ijk] = transcode(String, rstrip_cntrl_null_space(@view tdata[:, ijk]))
data = Array{String}(undef, rdims)::Array{String}
for ijk::CartesianIndex in CartesianIndices(data)
data[ijk] = transcode(String, rstrip_cntrl_null_space(@view tdata[:, ijk]))::String
end
else
data = [ transcode(String, rstrip_cntrl_null_space(tdata)) ]
Expand Down
147 changes: 87 additions & 60 deletions src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ function calcresiduals(x::AbstractVector, scale)
return (convert.(Int16, x) .% UInt8) .* abs(scale)
end

function calcresiduals!(x::AbstractVector, indices::Vector{T}, scale) where T
@simd ivdep for i in eachindex(x)
function calcresiduals!(x::AbstractVector{T}, indices::Vector{I}, scale) where {T,I}
@inbounds for i in eachindex(x)
if indices[i]
x[i] = (convert(Int16, x[i]) % UInt8) * abs(scale)
else
x[i] = x[i]
end
end

return nothing
end

function readdata(
io::IOStream, h::Header{END}, groups::LittleDict{Symbol,Group{END},Vector{Symbol},Vector{Group{END}}}, ::Type{F}
io::IO, h::Header{END}, groups::LittleDict{Symbol,Group{END},Vector{Symbol},Vector{Group{END}}}, ::Type{F}
) where {END<:AbstractEndian, F}
if iszero(groups[:POINT][Int, :DATA_START]-1)
if !iszero(h.datastart-1)
Expand All @@ -30,9 +28,10 @@ function readdata(
numframes::Int = numpointframes(groups)
nummarkers::Int = groups[:POINT][Int, :USED]
numchannels::Int = groups[:ANALOG][Int, :USED]
pointrate = get(groups[:POINT], (Float32, :RATE), h.pointrate)
if isinteger(get(groups[:ANALOG], (Float32, :RATE), pointrate)/pointrate)
aspf::Int = convert(Int, get(groups[:ANALOG], (Float32, :RATE), pointrate)/pointrate)
pointrate = get(groups[:POINT], (Float32, :RATE), h.pointrate::Float32)::Float32
analograte = get(groups[:ANALOG], (Float32, :RATE), pointrate)::Float32
if isinteger(analograte/pointrate)::Bool
aspf::Int = convert(Int, analograte/pointrate)
else
aspf = h.aspf
end
Expand Down Expand Up @@ -159,7 +158,10 @@ Read the C3D file at `fn`.
function readc3d(fn::AbstractString; paramsonly=false, validate=true,
handle_duplicate_parameters=:keeplast, missingpoints=true, strip_prefixes=false)

handle_duplicate_parameters ∈ (:drop, :keeplast, :append_position)
handle_duplicate_parameters ∈ (:drop, :keeplast, :append_count) || throw(ArgumentError(
"Invalid `handle_duplicate_parameters`. Got :$handle_duplicate_parameters. Check \
the docstring for valid options."))

io = open(fn, "r")

params_ptr = read(io, UInt8)
Expand Down Expand Up @@ -223,10 +225,6 @@ end

"requires a::Vector{Parameter} sorted by gid; returns views"
function split_filter!(f, a)
# true_is = findall(f, a)
# trues = a[true_is]

# falses = deleteat!(a, true_is)
if isempty(a)
return similar(a, ntuple(_-> 0, ndims(a))), @view a[end+1:end]
end
Expand All @@ -238,13 +236,12 @@ function split_filter!(f, a)
else
@views falses = a[last(true_is)+1:end]
end
# @debug eachindex(a), true_is trues, falses

return trues, falses
end

function isduplicate(x, a; by=identity)
return count(==(by(x)), a) > 1
return count(==(by(x))∘by, a) > 1
end

function _readparams(io::IO, paramblocks, ::Type{END}, handle_duplicate_parameters::Symbol) where {END}
Expand All @@ -253,8 +250,8 @@ function _readparams(io::IO, paramblocks, ::Type{END}, handle_duplicate_paramete
reset(io)

if !iszero(paramblocks)
gs = Array{Group{END},1}()
ps = Array{Parameter,1}()
gs = Vector{Group{END}}()
ps = Vector{Parameter}()
moreparams = true
fail = 0
np = 0
Expand Down Expand Up @@ -331,32 +328,38 @@ function _readparams(io::IO, paramblocks, ::Type{END}, handle_duplicate_paramete
end

group_names = map(g -> g.name, gs)
if !allunique(group_names)
if !allunique(gid.(gs))
dup_group_names = findduplicates(group_names)
if !isempty(dup_group_names)
# @debug "duplicate group names detected"
if !allunique(gid.(gs)) # Duplicate names with the same GID
for group in gs
if isduplicate(group, gs; by=(x->x.name))
@warn "Multiple groups with the same name \"$(group.name)\" and
group ID `$(gid(group))`. The second duplicate group will be
deleted to keep group names unique (no parameters will be
lost)."

i = findlast(g -> g.name == group, gs)
@assert !isnothing(i)
@assert group !== gs[i]
# i.e. gs[i] is *after* group and therefore deleting will only
# affect future iterations (by not running on the already-handled
# duplicate)

deleteat!(gs, i)
end
group.name ∈ dup_group_names || continue

@warn "Multiple groups with the same name \"$(group.name)\" and \
group ID `$(gid(group))`. The second duplicate group will be deleted to keep group names
unique (no parameters will be lost)." _id=stat(fd(io)).desc maxlog=1

# could be more than 1 duplicate
is = findall(==(group.name)∘(g->g.name), gs)
@assert !isempty(is)
@assert group === gs[first(is)]

# i.e. we're deleting *after* `group` and therefore deleting will
# only affect future iterations (by not running on the
# already-handled duplicate)

deleteat!(gs, is)
end
else
else # Duplicate names with different GIDs
for group in gs
if isduplicate(group, gs; by=(x->x.name))
@warn "Multiple groups with the same name \"$(group.name)\". The
group ID will be appended to the duplicate group names to keep
group names unique."
group.name = Symbol(group.name, "_", gid(group))
group.name ∈ dup_group_names || continue

# @debug group isduplicate(group, gs; by=(g->g.name))
@warn "Multiple groups with the same name \"$(group.name)\". The \
group ID will be appended to the duplicate group names to keep group names unique." _id=stat(fd(io)).desc maxlog=1
dups = findall(==(group.name)∘(g->g.name), gs)
for i in dups
gs[i].name = Symbol(gs[i].name, "_", gid(gs[i]))
end
end
end
Expand All @@ -368,40 +371,64 @@ function _readparams(io::IO, paramblocks, ::Type{END}, handle_duplicate_paramete
psv = @view ps[begin:end]

for group in sort(gs; by=gid)
# @debug "" group.name abs(group.gid)
group_params, psv = split_filter!(gid, psv)
if isempty(group_params)
# @debug "group $(group.name) is empty" group_params, psv
@debug "group $(group.name) is empty" group_params, psv
break
end
!allunique(group_params) || unique!(identity, group_params; seen=Set{Parameter}()) # remove literal duplicates
unique!(identity, group_params; seen=Set{Parameter}()) # remove literal duplicates

# check for params with duplicate names
param_names = map(g -> g.name, group_params)
# @debug param_names
uniq_names = unique(identity, param_names; seen=Set{Symbol}())
if length(uniq_names) < length(param_names) # !allunique(param_names)
param_names = map(name, group_params)
dup_names = findduplicates(param_names)
if !isempty(dup_names)
# @debug "duplicate parameters detected in" group
if handle_duplicate_parameters === :drop
unique!(p -> p.name, group_params; seen=Set{Parameter}())
# @debug "dropping duplicates"
unique!(name, group_params)
elseif handle_duplicate_parameters === :keeplast
for name in uniq_names
@views dups = findall(p -> p.name == name, group_params)[1:end-1]
deleteat!(group_params, dups)
# @debug "keeping last duplicate" issorted(_position.(group_params))
dups = similar(BitVector, axes(group_params, 1))
dups .= false
_dups = similar(dups)
for _name in dup_names
for i in eachindex(_dups)
_dups[i] = name(group_params[i]) === name
end
# if count(_dups) > 1
# @debug "found `duplicates` and keeping `last(duplicates)`" group_params[_dups], group_params[findlast(_dups)]
# end
count(_dups) > 1 || continue
i = findlast(_dups)
_dups[i] = false
dups .|= _dups
end
elseif handle_duplicate_parameters === :append_position
for name in uniq_names
dups = findall(p -> p.name == name, group_params)
for dup in dups
group_params[dup].name =
Symbol(group_params[dup].name, "_", group_params[dup].pos)
deleteat!(group_params, dups)
elseif handle_duplicate_parameters === :append_count
# @debug "appending duplicate count to duplicate names; $(length(dup_names)) parameters with duplicated names"
dups = similar(BitVector, axes(group_params, 1))
dups .= false
_dups = similar(dups)
for _name in dup_names
for i in eachindex(_dups)
_dups[i] = name(group_params[i]) === name
end
count(_dups) > 1 || continue
# @debug "$(count(_dups)) duplicate parameters with name $(_name)"
dups .|= _dups
end
cnts = Dict{Symbol,Int}(zip(dup_names, zeros(Int, axes(dup_names,1))))
for i in eachindex(dups)
dups[i] || continue
param = group_params[i]
cnt = cnts[param.name] += 1
# @debug "$(param.name) => $(Symbol(param.name, "_", cnt))"
param.name = Symbol(param.name, "_", cnt)
end
end
param_names = map(name, group_params)
end

sort!(group_params; by=_position)
param_names = map(g -> g.name, group_params)

sizehint!(group.params, length(group_params))
OrderedCollections.add_new!.(Ref(group.params), param_names, group_params)
end
Expand Down
Loading

0 comments on commit 43076ea

Please sign in to comment.