Skip to content

Commit 913e48e

Browse files
author
Janis Erdmanis
committed
fixes for OpenSSLGroups
1 parent d6e0744 commit 913e48e

File tree

4 files changed

+64
-45
lines changed

4 files changed

+64
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CryptoGroups"
22
uuid = "bc997328-bedd-407e-bcd3-5758e064a52d"
33
authors = ["Janis Erdmanis <[email protected]>"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
CryptoPRG = "d846c407-34c1-46cb-aa27-d51818cc05e2"

src/Curves/ecpoint.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ Base.:-(u::P, v::P) where P <: AbstractPoint = u + (-v)
3535

3636
Base.isless(x::P, y::P) where P <: AbstractPoint = gx(x) == gx(y) ? gx(x) < gx(y) : gy(x) < gy(y)
3737

38-
function validate(x::AbstractPoint, order::Integer, cofactor::Integer)
38+
function validate(x::P, order::Integer, cofactor::Integer) where P <: AbstractPoint
3939

4040
oncurve(x) || throw(ArgumentError("Point is not in curve"))
41-
x * cofactor != zero(x) || throw(ArgumentError("Point is in cofactor subgroup"))
41+
#x * cofactor != zero(P) || throw(ArgumentError("Point is in cofactor subgroup"))
42+
!iszero(x * cofactor) || throw(ArgumentError("Point is in cofactor subgroup"))
4243

4344
return
4445
end
@@ -130,7 +131,7 @@ name(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = isnothing(S.name) ?
130131

131132
eq(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = eq(P)
132133
field(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = field(P)
133-
134+
field(::Type{ECPoint{P}}) where {P <: AbstractPoint} = field(P)
134135

135136
"""
136137
zero(::Union{P, Type{P}}) where P <: AbstractPoint

src/macros.jl

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using .Specs: modp_spec
22

3-
# TODO: Add support for @PGroup{p = _p, q = _q} where _q and _p are defined out of the scope
43
macro PGroup(expr)
54
if expr.head == :braces
65
if length(expr.args) == 1 && !(expr.args[1] isa Expr)
@@ -9,15 +8,18 @@ macro PGroup(expr)
98
spec = modp_spec(name)
109
group = concretize_type(PGroup, spec; name)
1110
return group
11+
1212
else
13-
# Two-argument case: @PGroup{p=23, q=11}
13+
# Two-argument case: @PGroup{p=23, q=11} or @PGroup{p=my_p, q=my_q}
1414
p = q = nothing
1515
for arg in expr.args
1616
if arg isa Expr && arg.head == :(=)
17-
if arg.args[1] == :p
18-
p = arg.args[2]
19-
elseif arg.args[1] == :q
20-
q = arg.args[2]
17+
lhs = arg.args[1]
18+
rhs = arg.args[2]
19+
if lhs == :p
20+
p = rhs
21+
elseif lhs == :q
22+
q = rhs
2123
end
2224
end
2325
end
@@ -27,16 +29,22 @@ macro PGroup(expr)
2729
error("Both p and q must be specified in @PGroup{p=..., q=...}")
2830
end
2931

30-
spec = MODP(p; q)
31-
group = concretize_type(PGroup, spec)
32-
return group
32+
# Properly escape both p and q values
33+
return quote
34+
local p_val = $(esc(p))
35+
local q_val = $(esc(q))
36+
local spec = MODP(p_val; q=q_val)
37+
concretize_type(
38+
PGroup,
39+
spec
40+
)
41+
end
3342
end
3443
else
3544
error("Invalid syntax. Use @PGroup{p=..., q=...} or @PGroup{some_name}")
3645
end
3746
end
3847

39-
4048
Base.show(io::IO, ::Type{PGroup}) = print(io, "PGroup")
4149

4250
function Base.show(io::IO, ::Type{G}) where G <: PGroup
@@ -66,20 +74,51 @@ end
6674

6775

6876
macro ECGroup(expr)
69-
if expr.head == :braces && length(expr.args) == 1 && !(expr.args[1] isa Expr)
70-
# Single argument case: @PGroup{some_name}
71-
some_name = expr.args[1]
72-
73-
# If the curve can't be found error here
74-
spec = curve(some_name)
75-
group = concretize_type(ECGroup, spec)
77+
if expr.head == :braces && length(expr.args) == 1
78+
arg = expr.args[1]
79+
point_expr = Expr(:macrocall, Symbol("@ECPoint"), LineNumberNode(@__LINE__),
80+
Expr(:braces, arg))
81+
# Use __module__ to get the ECGroup type from the defining module
82+
return :(ECGroup{$(esc(point_expr))})
83+
else
84+
error("Invalid syntax. Use @ECGroup{curve_name} or @ECGroup{Module.curve_name}")
85+
end
86+
end
7687

77-
return group
88+
# First, let's modify @ECPoint to ensure it handles symbol quoting correctly
89+
macro ECPoint(expr)
90+
if expr.head == :braces && length(expr.args) == 1
91+
arg = expr.args[1]
92+
93+
# Handle module-qualified names (e.g., OpenSSLGroups.SecP256k1)
94+
if arg isa Expr && arg.head == :.
95+
return quote
96+
local P = $(esc(arg))
97+
concretize_type(ECPoint{P}, order(P), cofactor(P); name = nameof(P))
98+
end
99+
# Handle simple symbols
100+
elseif arg isa Symbol
101+
# Important: Use QuoteNode here for the isdefined check
102+
return quote
103+
if isdefined($(__module__), $(QuoteNode(arg)))
104+
# If defined, use the escaped symbol to access its value
105+
local P = $(esc(arg))
106+
concretize_type(ECPoint{P}, order(P), cofactor(P); name = nameof(P))
107+
else
108+
# If not defined, treat it as a curve name
109+
local spec = curve($(QuoteNode(arg)))
110+
concretize_type(ECPoint, spec)
111+
end
112+
end
113+
else
114+
error("Invalid syntax. Use @ECPoint{curve_name} or @ECPoint{Module.curve_name}")
115+
end
78116
else
79-
error("Invalid syntax. Use @ECGroup{curve_name}")
117+
error("Invalid syntax. Use @ECPoint{curve_name} or @ECPoint{Module.curve_name}")
80118
end
81119
end
82120

121+
83122
function Base.show(io::IO, g::G) where G <: ECGroup
84123
show(io, G)
85124
print(io, "(")
@@ -101,23 +140,6 @@ function Base.show(io::IO, ::Type{G}) where G <: ECGroup
101140
end
102141
end
103142

104-
105-
macro ECPoint(expr)
106-
if expr.head == :braces && length(expr.args) == 1 && !(expr.args[1] isa Expr)
107-
# Single argument case: @PGroup{some_name}
108-
some_name = expr.args[1]
109-
110-
spec = curve(some_name)
111-
# If the curve can't be found error here
112-
_curve = concretize_type(ECPoint, spec)
113-
114-
return _curve
115-
else
116-
error("Invalid syntax. Use @ECPoint{curve_name}")
117-
end
118-
end
119-
120-
121143
### May need to do epoint seperatelly
122144
function Base.show(io::IO, ::Type{P}) where P <: ECPoint
123145
if @isdefined P
@@ -131,7 +153,6 @@ function Base.show(io::IO, ::Type{P}) where P <: ECPoint
131153
end
132154
end
133155

134-
135156
function Base.show(io::IO, p::P) where P <: ECPoint
136157
show(io, P)
137158
print(io, "(")
@@ -144,6 +165,3 @@ end
144165
function Base.display(::Type{P}) where P <: ECPoint
145166
show(P)
146167
end
147-
148-
149-

src/spec.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,5 +158,5 @@ spec(g::ECGroup) = spec(g.x)
158158
spec(::Type{G}) where G <: PGroup = MODP(; p = modulus(G), q = order(G))
159159

160160
(::Type{P})() where P <: ECPoint = P(generator(curve(name(P))))
161-
(::Type{G})() where G <: ECGroup = G(generator(curve(name(G))))
161+
(::Type{ECGroup{P}})() where P <: ECPoint = ECGroup{P}(P())
162162
(::Type{G})() where G <: PGroup = G(generator(modp_spec(name(G))))

0 commit comments

Comments
 (0)