Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions edb/lib/std/26-bitwisefuncs.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,176 @@ std::bit_count(val: std::int64) -> std::int64
SELECT bit_count(val::bit(64))
$$;
};


## Bitwise bytes functions
## ----------------------

CREATE FUNCTION
std::bytes_and(a: bytes, b: bytes) -> bytes
{
CREATE ANNOTATION std::description :=
'Bitwise AND operator for bytes.';
SET volatility := 'Immutable';
USING SQL $$
SELECT CASE
WHEN octet_length($1) != octet_length($2) THEN
edgedb_VER.raise(
NULL::bytea,
'invalid_parameter_value',
msg => 'bytes_and: bytes must be of equal length'
)
ELSE (
WITH splits AS (
SELECT generate_series(1, octet_length($1)) AS i
)
SELECT string_agg(
decode(
lpad(
to_hex(
(get_byte($1, i-1) & get_byte($2, i-1))::int
),
2,
'0'
),
'hex'
),
''
)::bytea
FROM splits
)
END
$$;
};

CREATE FUNCTION
std::bytes_or(a: bytes, b: bytes) -> bytes
{
CREATE ANNOTATION std::description :=
'Bitwise OR operator for bytes.';
SET volatility := 'Immutable';
USING SQL $$
SELECT CASE
WHEN octet_length($1) != octet_length($2) THEN
edgedb_VER.raise(
NULL::bytea,
'invalid_parameter_value',
msg => 'bytes_or: bytes must be of equal length'
)
ELSE (
WITH splits AS (
SELECT generate_series(1, octet_length($1)) AS i
)
SELECT string_agg(
decode(
lpad(
to_hex(
(get_byte($1, i-1) | get_byte($2, i-1))::int
),
2,
'0'
),
'hex'
),
''
)::bytea
FROM splits
)
END
$$;
};

CREATE FUNCTION
std::bytes_xor(a: bytes, b: bytes) -> bytes
{
CREATE ANNOTATION std::description :=
'Bitwise XOR operator for bytes.';
SET volatility := 'Immutable';
USING SQL $$
SELECT CASE
WHEN octet_length($1) != octet_length($2) THEN
edgedb_VER.raise(
NULL::bytea,
'invalid_parameter_value',
msg => 'bytes_xor: bytes must be of equal length'
)
ELSE (
WITH splits AS (
SELECT generate_series(1, octet_length($1)) AS i
)
SELECT string_agg(
decode(
lpad(
to_hex(
(get_byte($1, i-1) # get_byte($2, i-1))::int
),
2,
'0'
),
'hex'
),
''
)::bytea
FROM splits
)
END
$$;
};

CREATE FUNCTION
std::bytes_not(a: bytes) -> bytes
{
CREATE ANNOTATION std::description :=
'Bitwise NOT operator for bytes.';
SET volatility := 'Immutable';
USING SQL $$
WITH splits AS (
SELECT generate_series(1, octet_length($1)) AS i
)
SELECT string_agg(
decode(
lpad(
to_hex(
(~get_byte($1, i-1) & 255)::int
),
2,
'0'
),
'hex'
),
''
)::bytea
FROM splits
$$;
};

CREATE FUNCTION
std::bytes_overlap(l: std::bytes, r: std::bytes) -> std::bool
{
CREATE ANNOTATION std::description :=
'Check if two bytes have any overlapping bits (bitwise AND is non-zero).';
SET volatility := 'Immutable';
USING SQL $$
SELECT (
CASE
WHEN octet_length(l) != octet_length(r) THEN
edgedb_VER.raise(
NULL::bool,
'invalid_parameter_value',
msg => (
'bytes_overlap(): bytes must be of equal length'
)
)
ELSE (
WITH splits AS (
SELECT generate_series(1, octet_length(l)) AS i
)
SELECT EXISTS (
SELECT 1 FROM splits
WHERE (get_byte(l, i-1) & get_byte(r, i-1)) != 0
)
)
END
)
$$;
};
199 changes: 199 additions & 0 deletions tests/test_edgeql_functions_bitwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import pytest
import base64
from edb.testbase import server as tb
from gel.errors import InvalidValueError
import json

class TestEdgeQLBitwiseBytesFunctions(tb.QueryTestCase):

async def _test_binary(self, query, expected):
# Test binary format using raw connection
result = await self.con.query_single(query)
assert result == expected

async def _test_json(self, query, expected):
# Test JSON format using JSON protocol
result = await self.con._fetchall_json(query)
parsed = json.loads(result)
assert parsed[0] == expected

async def test_bytes_and(self):
# Test basic AND operation
await self._test_binary(
r'''SELECT bytes_and(b'\xFF\x00', b'\x0F\x0F')''',
b'\x0F\x00'
)
await self._test_json(
r'''SELECT bytes_and(b'\xFF\x00', b'\x0F\x0F')''',
'DwA='
)

# Test with zeros
await self._test_binary(
r'''SELECT bytes_and(b'\x00\x00', b'\xFF\xFF')''',
b'\x00\x00'
)

# Test with all ones
await self._test_binary(
r'''SELECT bytes_and(b'\xFF\xFF', b'\xFF\xFF')''',
b'\xFF\xFF'
)

# Test error on different lengths
async with self.assertRaisesRegexTx(
InvalidValueError,
"bytes_and: bytes must be of equal length"
):
await self.con.query(
r'''SELECT bytes_and(b'\xFF', b'\xFF\xFF')'''
)

async def test_bytes_or(self):
# Test basic OR operation
await self._test_binary(
r'''SELECT bytes_or(b'\xF0\x00', b'\x0F\x0F')''',
b'\xFF\x0F'
)
await self._test_json(
r'''SELECT bytes_or(b'\xF0\x00', b'\x0F\x0F')''',
'/w8='
)

# Test with zeros
await self._test_binary(
r'''SELECT bytes_or(b'\x00\x00', b'\x00\x00')''',
b'\x00\x00'
)

# Test with all ones
await self._test_binary(
r'''SELECT bytes_or(b'\xFF\xFF', b'\x00\x00')''',
b'\xFF\xFF'
)

# Test error on different lengths
async with self.assertRaisesRegexTx(
InvalidValueError,
"bytes_or: bytes must be of equal length"
):
await self.con.query(
r'''SELECT bytes_or(b'\xFF', b'\xFF\xFF')'''
)

async def test_bytes_xor(self):
# Test basic XOR operation
await self._test_binary(
r'''SELECT bytes_xor(b'\xFF\x00', b'\x0F\x0F')''',
b'\xF0\x0F'
)
await self._test_json(
r'''SELECT bytes_xor(b'\xFF\x00', b'\x0F\x0F')''',
'8A8='
)

# Test with zeros
await self._test_binary(
r'''SELECT bytes_xor(b'\xFF\xFF', b'\x00\x00')''',
b'\xFF\xFF'
)

# Test with same values (should be zero)
await self._test_binary(
r'''SELECT bytes_xor(b'\xFF\xFF', b'\xFF\xFF')''',
b'\x00\x00'
)

# Test error on different lengths
async with self.assertRaisesRegexTx(
InvalidValueError,
"bytes_xor: bytes must be of equal length"
):
await self.con.query(
r'''SELECT bytes_xor(b'\xFF', b'\xFF\xFF')'''
)

async def test_bytes_not(self):
# Test basic NOT operation
await self._test_binary(
r'''SELECT bytes_not(b'\xFF\x00')''',
b'\x00\xFF'
)
await self._test_json(
r'''SELECT bytes_not(b'\xFF\x00')''',
'AP8='
)

# Test with all zeros
await self._test_binary(
r'''SELECT bytes_not(b'\x00\x00')''',
b'\xFF\xFF'
)

# Test with all ones
await self._test_binary(
r'''SELECT bytes_not(b'\xFF\xFF')''',
b'\x00\x00'
)

async def test_bytes_overlap(self):
# Test overlapping bytes
await self._test_binary(
r'''SELECT bytes_overlap(b'\xFF\x00', b'\xFF\xFF')''',
True
)

# Test non-overlapping bytes
await self._test_binary(
r'''SELECT bytes_overlap(b'\xF0\x00', b'\x0F\x00')''',
False
)

# Test with zeros
await self._test_binary(
r'''SELECT bytes_overlap(b'\x00\x00', b'\xFF\xFF')''',
False
)

# Test with all ones
await self._test_binary(
r'''SELECT bytes_overlap(b'\xFF\xFF', b'\xFF\xFF')''',
True
)

# Test error on different lengths
async with self.assertRaisesRegexTx(
InvalidValueError,
r"bytes_overlap\(\): bytes must be of equal length"
):
await self.con.query(
r'''SELECT bytes_overlap(b'\xFF', b'\xFF\xFF')'''
)

async def test_bytes_combinations(self):
# Test combining multiple operations
await self._test_binary(
r'''SELECT bytes_and(
bytes_or(b'\xF0\x00', b'\x0F\x0F'),
bytes_xor(b'\xFF\x00', b'\x0F\x0F')
)''',
b'\xF0\x0F'
)

# Test with NOT operations
await self._test_binary(
r'''SELECT bytes_and(
bytes_not(b'\x00\xFF'),
b'\xFF\x00'
)''',
b'\xFF\x00'
)

# Test overlap with combined operations
await self._test_binary(
r'''SELECT bytes_overlap(
bytes_and(b'\xFF\x00', b'\x0F\x0F'),
bytes_or(b'\xF0\x00', b'\x0F\x0F')
)''',
True
)
Loading