From 6f3c73c88a8e724b1b420637ab460dd9d4be0d3a Mon Sep 17 00:00:00 2001 From: VimT Date: Sat, 1 Jun 2024 17:32:12 +0800 Subject: [PATCH] Feature: support custom int/float type (#7) * support custom int/float type --- README.md | 52 +++ examples/csv_to_mmdb.py | 21 +- mmdb_writer.py | 449 +++++++++++++++------ setup.py | 32 +- tests/clients.py | 131 ++++++ tests/clients/go/go.mod | 7 + tests/clients/go/main.go | 59 +++ tests/clients/java/.gitignore | 38 ++ tests/clients/java/pom.xml | 62 +++ tests/clients/java/src/main/java/Main.java | 87 ++++ tests/record.py | 117 ++++++ tests/test.py | 85 ++-- 12 files changed, 959 insertions(+), 181 deletions(-) create mode 100644 tests/clients.py create mode 100644 tests/clients/go/go.mod create mode 100644 tests/clients/go/main.go create mode 100644 tests/clients/java/.gitignore create mode 100644 tests/clients/java/pom.xml create mode 100644 tests/clients/java/src/main/java/Main.java create mode 100644 tests/record.py diff --git a/README.md b/README.md index d111b7d..a5b88e6 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,58 @@ assert r == {'country': 'COUNTRY', 'isp': 'ISP'} ## Examples see [csv_to_mmdb.py](./examples/csv_to_mmdb.py) +Here is a professional and clear translation of the README.md section from Chinese into English: + +## Using the Java Client + +### TLDR + +When generating an MMDB file for use with the Java client, you must specify the `int_type`: + +```python +from mmdb_writer import MMDBWriter + +writer = MMDBWriter(int_type='int32') +``` + +Alternatively, you can explicitly specify data types using the [Type Enforcement](#type-enforcement) section. + +### Underlying Principles + +In Java, when deserializing to a structure, the numeric types will use the original MMDB numeric types. The specific +conversion relationships are as follows: + +| mmdb type | java type | +|--------------|------------| +| float (15) | Float | +| double (3) | Double | +| int32 (8) | Integer | +| uint16 (5) | Integer | +| uint32 (6) | Long | +| uint64 (9) | BigInteger | +| uint128 (10) | BigInteger | + +When using the Python writer to generate an MMDB file, by default, it converts integers to the corresponding MMDB type +based on the size of the `int`. For instance, `int(1)` would convert to `uint16`, and `int(2**16+1)` would convert +to `uint32`. This may cause deserialization failures in Java clients. Therefore, it is necessary to specify +the `int_type` parameter when generating MMDB files to define the numeric type accurately. + +## Type Enforcement + +MMDB supports a variety of numeric types such as `int32`, `uint16`, `uint32`, `uint64`, `uint128` for integers, +and `f32`, `f64` for floating points, while Python only has one integer type and one float type (actually `f64`). + +Therefore, when generating an MMDB file, you need to specify the `int_type` parameter to define the numeric type of the +MMDB file. The behaviors for different `int_type` settings are: + +| int_type | Behavior | +|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| auto (default) | Automatically selects the MMDB numeric type based on the value size.
Rules:
`int32` for value < 0
`uint16` for 0 <= value < 2^16
`uint32` for 2^16 <= value < 2^32
`uint64` for 2^32 <= value < 2^64
`uint128` for value >= 2^64. | +| i32 | Stores all integer types as `int32`. | +| u16 | Stores all integer types as `uint16`. | +| u32 | Stores all integer types as `uint32`. | +| u64 | Stores all integer types as `uint64`. | +| u128 | Stores all integer types as `uint128`. | ## Reference: diff --git a/examples/csv_to_mmdb.py b/examples/csv_to_mmdb.py index b4a4f75..28fdcd3 100644 --- a/examples/csv_to_mmdb.py +++ b/examples/csv_to_mmdb.py @@ -8,25 +8,30 @@ def main(): - writer = MMDBWriter(4, 'Test.GeoIP', languages=['EN'], description="Test IP library") + writer = MMDBWriter( + 4, "Test.GeoIP", languages=["EN"], description="Test IP library" + ) data = defaultdict(list) # merge cidr - with open('fake_ip_info.csv', 'r') as f: + with open("fake_ip_info.csv", "r") as f: reader = csv.DictReader(f) for line in reader: - data[(line['country'], line['isp'])].append(IPNetwork(f'{line["ip"]}/{line["prefixlen"]}')) + data[(line["country"], line["isp"])].append( + IPNetwork(f'{line["ip"]}/{line["prefixlen"]}') + ) for index, cidrs in data.items(): - writer.insert_network(IPSet(cidrs), {'country': index[0], 'isp': index[1]}) - writer.to_db_file('fake_ip_library.mmdb') + writer.insert_network(IPSet(cidrs), {"country": index[0], "isp": index[1]}) + writer.to_db_file("fake_ip_library.mmdb") def test_read(): import maxminddb - m = maxminddb.open_database('fake_ip_library.mmdb') - r = m.get('3.1.1.1') + + m = maxminddb.open_database("fake_ip_library.mmdb") + r = m.get("3.1.1.1") print(r) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmdb_writer.py b/mmdb_writer.py index 8b57540..70949bb 100644 --- a/mmdb_writer.py +++ b/mmdb_writer.py @@ -1,20 +1,100 @@ # coding: utf-8 -__version__ = '0.1.1' +__version__ = "0.2.1" import logging import math import struct import time -from typing import Union from decimal import Decimal +from enum import IntEnum +from typing import Union, List, Dict, Literal from netaddr import IPSet, IPNetwork -MMDBType = Union[dict, list, str, bytes, int, bool] + +class MmdbBaseType(object): + def __init__(self, value): + self.value = value + + +# type hint +class MmdbF32(MmdbBaseType): + def __init__(self, value: float): + super().__init__(value) + + +class MmdbF64(MmdbBaseType): + def __init__(self, value: Union[float | Decimal]): + super().__init__(value) + + +class MmdbI32(MmdbBaseType): + def __init__(self, value: int): + super().__init__(value) + + +class MmdbU16(MmdbBaseType): + def __init__(self, value: int): + super().__init__(value) + + +class MmdbU32(MmdbBaseType): + def __init__(self, value: int): + super().__init__(value) + + +class MmdbU64(MmdbBaseType): + def __init__(self, value: int): + super().__init__(value) + + +class MmdbU128(MmdbBaseType): + def __init__(self, value: int): + super().__init__(value) + + +MMDBType = Union[ + dict, + list, + str, + bytes, + int, + bool, + MmdbF32, + MmdbF64, + MmdbI32, + MmdbU16, + MmdbU32, + MmdbU64, + MmdbU128, +] logger = logging.getLogger(__name__) -METADATA_MAGIC = b'\xab\xcd\xefMaxMind.com' +METADATA_MAGIC = b"\xab\xcd\xefMaxMind.com" + + +class MMDBTypeID(IntEnum): + POINTER = 1 + STRING = 2 + DOUBLE = 3 + BYTES = 4 + UINT16 = 5 + UINT32 = 6 + MAP = 7 + INT32 = 8 + UINT64 = 9 + UINT128 = 10 + ARRAY = 11 + DATA_CACHE = 12 + END_MARKER = 13 + BOOLEAN = 14 + FLOAT = 15 + + +UINT16_MAX = 0xFFFF +UINT32_MAX = 0xFFFFFFFF +UINT64_MAX = 0xFFFFFFFFFFFFFFFF class SearchTreeNode(object): @@ -53,55 +133,97 @@ def __repr__(self): __str__ = __repr__ +IntType = Union[ + Literal["auto", "u16", "u32", "u64", "u128", "i32"] + | MmdbU16 + | MmdbU32 + | MmdbU64 + | MmdbU128 + | MmdbI32 +] +FloatType = Union[Literal["f32", "f64"] | MmdbF32 | MmdbF64] + + class Encoder(object): + def __init__( + self, cache=True, int_type: IntType = "auto", float_type: FloatType = "f64" + ): + self.cache = cache + self.int_type = int_type + self.float_type = float_type - def __init__(self, cache=True): self.data_cache = {} self.data_list = [] self.data_pointer = 0 - - self.cache = cache + self._python_type_id = { + float: MMDBTypeID.DOUBLE, + bool: MMDBTypeID.BOOLEAN, + list: MMDBTypeID.ARRAY, + dict: MMDBTypeID.MAP, + bytes: MMDBTypeID.BYTES, + str: MMDBTypeID.STRING, + MmdbF32: MMDBTypeID.FLOAT, + MmdbF64: MMDBTypeID.DOUBLE, + MmdbI32: MMDBTypeID.INT32, + MmdbU16: MMDBTypeID.UINT16, + MmdbU32: MMDBTypeID.UINT32, + MmdbU64: MMDBTypeID.UINT64, + MmdbU128: MMDBTypeID.UINT128, + } def _encode_pointer(self, value): pointer = value if pointer >= 134744064: - res = struct.pack('>BI', 0x38, pointer) + res = struct.pack(">BI", 0x38, pointer) elif pointer >= 526336: pointer -= 526336 - res = struct.pack('>BBBB', 0x30 + ((pointer >> 24) & 0x07), - (pointer >> 16) & 0xff, (pointer >> 8) & 0xff, - pointer & 0xff) + res = struct.pack( + ">BBBB", + 0x30 + ((pointer >> 24) & 0x07), + (pointer >> 16) & 0xFF, + (pointer >> 8) & 0xFF, + pointer & 0xFF, + ) elif pointer >= 2048: pointer -= 2048 - res = struct.pack('>BBB', 0x28 + ((pointer >> 16) & 0x07), - (pointer >> 8) & 0xff, pointer & 0xff) + res = struct.pack( + ">BBB", + 0x28 + ((pointer >> 16) & 0x07), + (pointer >> 8) & 0xFF, + pointer & 0xFF, + ) else: - res = struct.pack('>BB', 0x20 + ((pointer >> 8) & 0x07), - pointer & 0xff) + res = struct.pack(">BB", 0x20 + ((pointer >> 8) & 0x07), pointer & 0xFF) return res def _encode_utf8_string(self, value): - encoded_value = value.encode('utf-8') - res = self._make_header(2, len(encoded_value)) + encoded_value = value.encode("utf-8") + res = self._make_header(MMDBTypeID.STRING, len(encoded_value)) res += encoded_value return res def _encode_bytes(self, value): - return self._make_header(4, len(value)) + value + return self._make_header(MMDBTypeID.BYTES, len(value)) + value def _encode_uint(self, type_id, max_len): + value_max = 2 ** (max_len * 8) + def _encode_unsigned_value(value): - res = b'' + if value < 0 or value >= value_max: + raise ValueError( + f"encode uint{max_len * 8} fail: {value} not in range(0, {value_max})" + ) + res = b"" while value != 0 and len(res) < max_len: - res = struct.pack('>B', value & 0xff) + res + res = struct.pack(">B", value & 0xFF) + res value = value >> 8 return self._make_header(type_id, len(res)) + res return _encode_unsigned_value def _encode_map(self, value): - res = self._make_header(7, len(value)) + res = self._make_header(MMDBTypeID.MAP, len(value)) for k, v in list(value.items()): # Keys are always stored by value. res += self.encode(k) @@ -109,13 +231,13 @@ def _encode_map(self, value): return res def _encode_array(self, value): - res = self._make_header(11, len(value)) + res = self._make_header(MMDBTypeID.ARRAY, len(value)) for k in value: res += self.encode(k) return res def _encode_boolean(self, value): - return self._make_header(14, 1 if value else 0) + return self._make_header(MMDBTypeID.BOOLEAN, 1 if value else 0) def _encode_pack_type(self, type_id, fmt): def pack_type(value): @@ -124,97 +246,109 @@ def pack_type(value): return pack_type - _type_decoder = None + _type_encoder = None @property - def type_decoder(self): - if self._type_decoder is None: - self._type_decoder = { - 1: self._encode_pointer, - 2: self._encode_utf8_string, - 3: self._encode_pack_type(3, '>d'), # double, - 4: self._encode_bytes, - 5: self._encode_uint(5, 2), # uint16 - 6: self._encode_uint(6, 4), # uint32 - 7: self._encode_map, - 8: self._encode_pack_type(8, '>i'), # int32 - 9: self._encode_uint(9, 8), # uint64 - 10: self._encode_uint(10, 16), # uint128 - 11: self._encode_array, - 14: self._encode_boolean, - 15: self._encode_pack_type(15, '>f'), # float, + def type_encoder(self): + if self._type_encoder is None: + self._type_encoder = { + MMDBTypeID.POINTER: self._encode_pointer, + MMDBTypeID.STRING: self._encode_utf8_string, + MMDBTypeID.DOUBLE: self._encode_pack_type(MMDBTypeID.DOUBLE, ">d"), + MMDBTypeID.BYTES: self._encode_bytes, + MMDBTypeID.UINT16: self._encode_uint(MMDBTypeID.UINT16, 2), + MMDBTypeID.UINT32: self._encode_uint(MMDBTypeID.UINT32, 4), + MMDBTypeID.MAP: self._encode_map, + MMDBTypeID.INT32: self._encode_pack_type(MMDBTypeID.INT32, ">i"), + MMDBTypeID.UINT64: self._encode_uint(MMDBTypeID.UINT64, 8), + MMDBTypeID.UINT128: self._encode_uint(MMDBTypeID.UINT128, 16), + MMDBTypeID.ARRAY: self._encode_array, + MMDBTypeID.BOOLEAN: self._encode_boolean, + MMDBTypeID.FLOAT: self._encode_pack_type(MMDBTypeID.FLOAT, ">f"), } - return self._type_decoder + return self._type_encoder def _make_header(self, type_id, length): if length >= 16843036: - raise Exception('length >= 16843036') + raise Exception("length >= 16843036") elif length >= 65821: five_bits = 31 length -= 65821 - b3 = length & 0xff - b2 = (length >> 8) & 0xff - b1 = (length >> 16) & 0xff - additional_length_bytes = struct.pack('>BBB', b1, b2, b3) + b3 = length & 0xFF + b2 = (length >> 8) & 0xFF + b1 = (length >> 16) & 0xFF + additional_length_bytes = struct.pack(">BBB", b1, b2, b3) elif length >= 285: five_bits = 30 length -= 285 - b2 = length & 0xff - b1 = (length >> 8) & 0xff - additional_length_bytes = struct.pack('>BB', b1, b2) + b2 = length & 0xFF + b1 = (length >> 8) & 0xFF + additional_length_bytes = struct.pack(">BB", b1, b2) elif length >= 29: five_bits = 29 length -= 29 - additional_length_bytes = struct.pack('>B', length & 0xff) + additional_length_bytes = struct.pack(">B", length & 0xFF) else: five_bits = length - additional_length_bytes = b'' + additional_length_bytes = b"" if type_id <= 7: - res = struct.pack('>B', (type_id << 5) + five_bits) + res = struct.pack(">B", (type_id << 5) + five_bits) else: - res = struct.pack('>BB', five_bits, type_id - 7) + res = struct.pack(">BB", five_bits, type_id - 7) return res + additional_length_bytes - _python_type_id = { - float: 15, - bool: 14, - list: 11, - dict: 7, - bytes: 4, - str: 2 - } - def python_type_id(self, value): value_type = type(value) type_id = self._python_type_id.get(value_type) if type_id: return type_id if value_type is int: - if value > 0xffffffffffffffff: - return 10 - elif value > 0xffffffff: - return 9 - elif value > 0xffff: - return 6 - elif value < 0: - return 8 + if self.int_type == "auto": + if value > UINT64_MAX: + return MMDBTypeID.UINT128 + elif value > UINT32_MAX: + return MMDBTypeID.UINT64 + elif value > UINT16_MAX: + return MMDBTypeID.UINT32 + elif value < 0: + return MMDBTypeID.INT32 + else: + return MMDBTypeID.UINT16 + elif self.int_type in ("u16", MmdbU16): + return MMDBTypeID.UINT16 + elif self.int_type in ("u32", MmdbU32): + return MMDBTypeID.UINT32 + elif self.int_type in ("u64", MmdbU64): + return MMDBTypeID.UINT64 + elif self.int_type in ("u128", MmdbU128): + return MMDBTypeID.UINT128 + elif self.int_type in ("i32", MmdbI32): + return MMDBTypeID.INT32 + elif value_type is float: + if self.float_type in ("f32", MmdbF32): + return MMDBTypeID.FLOAT else: - return 5 - if value_type is Decimal: - return 3 + return MMDBTypeID.DOUBLE + elif value_type is Decimal: + return MMDBTypeID.DOUBLE raise TypeError("unknown type {value_type}".format(value_type=value_type)) def encode_meta(self, meta): - res = self._make_header(7, len(meta)) - meta_type = {'node_count': 6, 'record_size': 5, 'ip_version': 5, - 'binary_format_major_version': 5, 'binary_format_minor_version': 5, - 'build_epoch': 9} + res = self._make_header(MMDBTypeID.MAP, len(meta)) + meta_type = { + "node_count": 6, + "record_size": 5, + "ip_version": 5, + "binary_format_major_version": 5, + "binary_format_minor_version": 5, + "build_epoch": 9, + } for k, v in list(meta.items()): # Keys are always stored by value. res += self.encode(k) @@ -232,9 +366,12 @@ def encode(self, value, type_id=None): type_id = self.python_type_id(value) try: - encoder = self.type_decoder[type_id] + encoder = self.type_encoder[type_id] except KeyError: raise ValueError("unknown type_id={type_id}".format(type_id=type_id)) + + if isinstance(value, MmdbBaseType): + value = value.value res = encoder(value) if self.cache: @@ -256,7 +393,13 @@ def encode(self, value, type_id=None): class TreeWriter(object): encoder_cls = Encoder - def __init__(self, tree, meta): + def __init__( + self, + tree: "SearchTreeNode", + meta: dict, + int_type: IntType = "auto", + float_type: FloatType = "f64", + ): self._node_idx = {} self._leaf_offset = {} self._node_list = [] @@ -266,7 +409,9 @@ def __init__(self, tree, meta): self.tree = tree self.meta = meta - self.encoder = self.encoder_cls(cache=True) + self.encoder = self.encoder_cls( + cache=True, int_type=int_type, float_type=float_type + ) @property def _data_list(self): @@ -280,7 +425,7 @@ def _build_meta(self): return { "node_count": self._node_counter, "record_size": self.record_size, - **self.meta + **self.meta, } def _adjust_record_size(self): @@ -297,7 +442,7 @@ def _adjust_record_size(self): elif bit_count <= 32: self.record_size = 32 else: - raise Exception('record_size > 32') + raise Exception("record_size > 32") self.data_offset = self.record_size * 2 / 8 * self._node_counter @@ -335,40 +480,39 @@ def _cal_node_bytes(self, node) -> bytes: right_idx = self._calc_record_idx(node.right) if self.record_size == 24: - b1 = (left_idx >> 16) & 0xff - b2 = (left_idx >> 8) & 0xff - b3 = left_idx & 0xff - b4 = (right_idx >> 16) & 0xff - b5 = (right_idx >> 8) & 0xff - b6 = right_idx & 0xff - return struct.pack('>BBBBBB', b1, b2, b3, b4, b5, b6) + b1 = (left_idx >> 16) & 0xFF + b2 = (left_idx >> 8) & 0xFF + b3 = left_idx & 0xFF + b4 = (right_idx >> 16) & 0xFF + b5 = (right_idx >> 8) & 0xFF + b6 = right_idx & 0xFF + return struct.pack(">BBBBBB", b1, b2, b3, b4, b5, b6) elif self.record_size == 28: - b1 = (left_idx >> 16) & 0xff - b2 = (left_idx >> 8) & 0xff - b3 = left_idx & 0xff - b4 = ((left_idx >> 24) & 0xf) * 16 + \ - ((right_idx >> 24) & 0xf) - b5 = (right_idx >> 16) & 0xff - b6 = (right_idx >> 8) & 0xff - b7 = right_idx & 0xff - return struct.pack('>BBBBBBB', b1, b2, b3, b4, b5, b6, b7) + b1 = (left_idx >> 16) & 0xFF + b2 = (left_idx >> 8) & 0xFF + b3 = left_idx & 0xFF + b4 = ((left_idx >> 24) & 0xF) * 16 + ((right_idx >> 24) & 0xF) + b5 = (right_idx >> 16) & 0xFF + b6 = (right_idx >> 8) & 0xFF + b7 = right_idx & 0xFF + return struct.pack(">BBBBBBB", b1, b2, b3, b4, b5, b6, b7) elif self.record_size == 32: - return struct.pack('>II', left_idx, right_idx) + return struct.pack(">II", left_idx, right_idx) else: - raise Exception('self.record_size > 32') + raise Exception("self.record_size > 32") def write(self, fname): self._enumerate_nodes(self.tree) self._adjust_record_size() - with open(fname, 'wb') as f: + with open(fname, "wb") as f: for node in self._node_list: f.write(self._cal_node_bytes(node)) - f.write(b'\x00' * 16) + f.write(b"\x00" * 16) for element in self._data_list: f.write(element) @@ -378,14 +522,35 @@ def write(self, fname): def bits_rstrip(n, length=None, keep=0): - return map(int, bin(n)[2:].rjust(length, '0')[:keep]) + return map(int, bin(n)[2:].rjust(length, "0")[:keep]) class MMDBWriter(object): - - def __init__(self, ip_version=4, database_type='GeoIP', - languages=None, description='GeoIP db', - ipv4_compatible=False): + def __init__( + self, + ip_version=4, + database_type="GeoIP", + languages: List[str] = None, + description: Union[Dict[str, str] | str] = "GeoIP db", + ipv4_compatible=False, + int_type: IntType = "auto", + float_type: FloatType = "f64", + ): + """ + Args: + ip_version (int, optional): The IP version of the database. Defaults to 4. + database_type (str, optional): The type of the database. Defaults to "GeoIP". + languages (List[str], optional): A list of languages. Defaults to []. + description (Union[Dict[str, str], str], optional): A description of the database for every language. + ipv4_compatible (bool, optional): Whether the database is compatible with IPv4. Defaults to False. + int_type (Union[str, MmdbU16, MmdbU32, MmdbU64, MmdbU128, MmdbI32], optional): The type of integer to use. Defaults to "auto". + float_type (Union[str, MmdbF32, MmdbF64], optional): The type of float to use. Defaults to "f64". + + Note: + If you want to store an IPv4 address in an IPv6 database, you should set ipv4_compatible=True. + + If you want to use a specific integer type, you can set int_type to "u16", "u32", "u64", "u128", or "i32". + """ self.tree = SearchTreeNode() self.ipv4_compatible = ipv4_compatible @@ -401,43 +566,87 @@ def __init__(self, ip_version=4, database_type='GeoIP', self._bit_length = 128 if ip_version == 6 else 32 if ip_version not in [4, 6]: - raise ValueError("ip_version should be 4 or 6, {} is incorrect".format(ip_version)) + raise ValueError( + "ip_version should be 4 or 6, {} is incorrect".format(ip_version) + ) if ip_version == 4 and ipv4_compatible: raise ValueError("ipv4_compatible=True can set when ip_version=6") if not self.binary_format_major_version: - raise ValueError("major_version can't be empty or 0: {}".format(self.binary_format_major_version)) + raise ValueError( + "major_version can't be empty or 0: {}".format( + self.binary_format_major_version + ) + ) if isinstance(description, str): self.description = {i: description for i in languages} for i in languages: if i not in self.description: raise ValueError("language {} must have description!") - def insert_network(self, network: IPSet, content: MMDBType): + self.int_type = int_type + self.float_type = float_type + + def insert_network( + self, network: IPSet, content: MMDBType, overwrite=True, python_type_id_map=None + ): + """ + Inserts a network into the MaxMind database. + + Args: + network (IPSet): The network to be inserted. It should be an instance of netaddr.IPSet. + content (MMDBType): The content associated with the network. It can be a dictionary, list, string, bytes, integer, or boolean. + overwrite (bool, optional): If True, existing network data will be overwritten. Defaults to True. + python_type_id_map: abc + + Raises: + ValueError: If the network is not an instance of netaddr.IPSet. + ValueError: If an IPv6 address is inserted into an IPv4-only database. + ValueError: If an IPv4 address is inserted into an IPv6 database without setting ipv4_compatible=True. + + Note: + This method modifies the internal tree structure of the MMDBWriter instance. + """ leaf = SearchTreeLeaf(content) if not isinstance(network, IPSet): raise ValueError("network type should be netaddr.IPSet.") network = network.iter_cidrs() for cidr in network: if self.ip_version == 4 and cidr.version == 6: - raise ValueError('You inserted a IPv6 address {} ' - 'to an IPv4-only database.'.format(cidr)) + raise ValueError( + "You inserted a IPv6 address {} " + "to an IPv4-only database.".format(cidr) + ) if self.ip_version == 6 and cidr.version == 4: if not self.ipv4_compatible: - raise ValueError("You inserted a IPv4 address {} to an IPv6 database." - "Please use ipv4_compatible=True option store " - "IPv4 address in IPv6 database as ::/96 format".format(cidr)) + raise ValueError( + "You inserted a IPv4 address {} to an IPv6 database." + "Please use ipv4_compatible=True option store " + "IPv4 address in IPv6 database as ::/96 format".format(cidr) + ) cidr = cidr.ipv6(True) node = self.tree bits = list(bits_rstrip(cidr.value, self._bit_length, cidr.prefixlen)) current_node = node supernet_leaf = None # Tracks whether we are inserting into a subnet - for (index, ip_bit) in enumerate(bits[:-1]): + for index, ip_bit in enumerate(bits[:-1]): previous_node = current_node current_node = previous_node.get_or_create(ip_bit) if isinstance(current_node, SearchTreeLeaf): - current_cidr = IPNetwork((int(''.join(map(str, bits[:index + 1])).ljust(self._bit_length, '0'), 2), index + 1)) - logger.info(f"Inserting {cidr} ({content}) into subnet of {current_cidr} ({current_node.value})") + current_cidr = IPNetwork( + ( + int( + "".join(map(str, bits[: index + 1])).ljust( + self._bit_length, "0" + ), + 2, + ), + index + 1, + ) + ) + logger.info( + f"Inserting {cidr} ({content}) into subnet of {current_cidr} ({current_node.value})" + ) supernet_leaf = current_node current_node = SearchTreeNode() previous_node[ip_bit] = current_node @@ -449,7 +658,9 @@ def insert_network(self, network: IPSet, content: MMDBType): current_node[bits[-1]] = leaf def to_db_file(self, filename: str): - return TreeWriter(self.tree, self._build_meta()).write(filename) + return TreeWriter( + self.tree, self._build_meta(), self.int_type, self.float_type + ).write(filename) def _build_meta(self): return { diff --git a/setup.py b/setup.py index c881b0f..a771eb0 100644 --- a/setup.py +++ b/setup.py @@ -10,10 +10,9 @@ def get_version(file): return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1) -version = get_version('mmdb_writer.py') -f = open(os.path.join(os.path.dirname(__file__), 'README.md')) -readme = f.read() -f.close() +version = get_version("mmdb_writer.py") +with open(os.path.join(os.path.dirname(__file__), "README.md")) as f: + readme = f.read() setup( name="mmdb_writer", @@ -21,22 +20,19 @@ def get_version(file): description="Make `mmdb` format ip library file which can be read by maxmind official language reader", long_description=readme, long_description_content_type="text/markdown", - author='VimT', - author_email='me@vimt.me', - url='https://github.com/VimT/MaxMind-DB-Writer-python', - py_modules=['mmdb_writer'], - platforms=['any'], + author="VimT", + author_email="me@vimt.me", + url="https://github.com/VimT/MaxMind-DB-Writer-python", + py_modules=["mmdb_writer"], + platforms=["any"], zip_safe=False, python_requires=">=3.6", - install_requires=['netaddr>=0.7'], - tests_require=['maxminddb>=1.5'], + install_requires=["netaddr>=0.7"], + tests_require=["maxminddb>=1.5", "numpy"], classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: Implementation :: CPython', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: Implementation :: CPython", ], ) diff --git a/tests/clients.py b/tests/clients.py new file mode 100644 index 0000000..e796570 --- /dev/null +++ b/tests/clients.py @@ -0,0 +1,131 @@ +import base64 +import json +import logging +import os +import subprocess +import unittest +from pathlib import Path + +import maxminddb +from netaddr.ip.sets import IPSet + +from mmdb_writer import MMDBWriter, MmdbBaseType, MmdbF32 +from tests.record import Record + +logging.basicConfig( + format="[%(asctime)s: %(levelname)s] %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + +BASE_DIR = Path(__file__).parent.absolute() + + +def run(command: list): + print(f"Running command: {command}") + result = subprocess.run(command, check=True, stdout=subprocess.PIPE) + return result.stdout + + +class TestClients(unittest.TestCase): + def setUp(self) -> None: + self.filepath = Path("_test.mmdb").absolute() + self.filepath.unlink(True) + self.ip = "1.1.1.1" + self.origin_data = Record.random() + self.generate_mmdb() + self.maxDiff = None + + def tearDown(self) -> None: + self.filepath.unlink(True) + + def generate_mmdb(self): + ip_version = 4 + database_type = "test_client" + languages = ["en"] + description = {"en": "for testing purposes only"} + writer = MMDBWriter( + ip_version=ip_version, + database_type=database_type, + languages=languages, + description=description, + ipv4_compatible=False, + ) + + writer.insert_network(IPSet(["1.0.0.0/8"]), self.origin_data.dict()) + + # insert other useless record + for i in range(2, 250): + info = Record.random() + writer.insert_network(IPSet([f"{i}.0.0.0/8"]), info.dict()) + + writer.to_db_file(str(self.filepath)) + + @staticmethod + def convert_bytes(d, bytes_convert, f32_convert=lambda x: float(str(x))): + def inner(d): + if isinstance(d, bytes): + return bytes_convert(d) + elif isinstance(d, dict): + return {k: inner(v) for k, v in d.items()} + elif isinstance(d, list): + return [inner(i) for i in d] + elif isinstance(d, MmdbF32): + return f32_convert(d.value) + elif isinstance(d, MmdbBaseType): + return d.value + else: + return d + + return inner(d) + + def test_python(self): + for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE): + m = maxminddb.open_database(self.filepath, mode=mode) + python_data = m.get(self.ip) + should_data = self.origin_data.dict() + should_data = self.convert_bytes( + should_data, lambda x: bytearray(x), lambda x: float(x) + ) + self.assertDictEqual(should_data, python_data) + m.close() + + def test_java(self): + java_dir = BASE_DIR / "clients" / "java" + self.assertTrue(java_dir.exists()) + os.chdir(java_dir) + run(["mvn", "clean", "package"]) + java_data_str = run( + [ + "java", + "-jar", + "target/mmdb-test-jar-with-dependencies.jar", + "-db", + str(self.filepath), + "-ip", + self.ip, + ] + ) + java_data = json.loads(java_data_str) + should_data = self.origin_data.dict() + + # java bytes marshal as i8 list + should_data = self.convert_bytes( + should_data, lambda x: [i if i <= 127 else i - 256 for i in x] + ) + self.assertDictEqual(should_data, java_data) + + def test_go(self): + go_dir = BASE_DIR / "clients" / "go" + self.assertTrue(go_dir.exists()) + os.chdir(go_dir) + go_data_str = run( + ["go", "run", "main.go", "-db", str(self.filepath), "-ip", self.ip] + ) + go_data = json.loads(go_data_str) + + should_data = self.origin_data.dict() + # go bytes marshal as base64 str + should_data = self.convert_bytes( + should_data, lambda x: base64.b64encode(x).decode() + ) + self.assertDictEqual(should_data, go_data) diff --git a/tests/clients/go/go.mod b/tests/clients/go/go.mod new file mode 100644 index 0000000..4856216 --- /dev/null +++ b/tests/clients/go/go.mod @@ -0,0 +1,7 @@ +module mmdb-test + +go 1.22 + + require github.com/oschwald/maxminddb-golang v1.12.0 + + require golang.org/x/sys v0.10.0 // indirect diff --git a/tests/clients/go/main.go b/tests/clients/go/main.go new file mode 100644 index 0000000..984aa67 --- /dev/null +++ b/tests/clients/go/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "github.com/oschwald/maxminddb-golang" + "log" + "math/big" + "net" + "os" +) + +var ( + db = flag.String("db", "", "Path to the MaxMind DB file") + ip = flag.String("ip", "", "IP address to look up") +) + +type Record struct { + I32 int `json:"i32" maxminddb:"i32"` + F32 float32 `json:"f32" maxminddb:"f32"` + F64 float64 `json:"f64" maxminddb:"f64"` + U16 uint16 `json:"u16" maxminddb:"u16"` + U32 uint32 `json:"u32" maxminddb:"u32"` + U64 uint64 `json:"u64" maxminddb:"u64"` + U128 *big.Int `json:"u128" maxminddb:"u128"` + Array []any `json:"array" maxminddb:"array"` + Map map[string]any `json:"map" maxminddb:"map"` + Bytes []byte `json:"bytes" maxminddb:"bytes"` + String string `json:"string" maxminddb:"string"` + Bool bool `json:"bool" maxminddb:"bool"` +} + +func main() { + flag.Parse() + if *db == "" || *ip == "" { + flag.PrintDefaults() + os.Exit(1) + } + db, err := maxminddb.Open(*db) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + ip := net.ParseIP(*ip) + + var record Record + + err = db.Lookup(ip, &record) + if err != nil { + log.Panic(err) + } + data, err := json.Marshal(record) + if err != nil { + log.Panic(err) + } + fmt.Println(string(data)) +} diff --git a/tests/clients/java/.gitignore b/tests/clients/java/.gitignore new file mode 100644 index 0000000..5ff6309 --- /dev/null +++ b/tests/clients/java/.gitignore @@ -0,0 +1,38 @@ +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ + +### IntelliJ IDEA ### +.idea/modules.xml +.idea/jarRepositories.xml +.idea/compiler.xml +.idea/libraries/ +*.iws +*.iml +*.ipr + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store \ No newline at end of file diff --git a/tests/clients/java/pom.xml b/tests/clients/java/pom.xml new file mode 100644 index 0000000..b2a0ac0 --- /dev/null +++ b/tests/clients/java/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + me.vime + mmdb-test + 1.0-SNAPSHOT + + + 22 + 22 + UTF-8 + + + + com.maxmind.db + maxmind-db + 3.1.0 + + + args4j + args4j + 2.33 + + + com.google.code.gson + gson + 2.10.1 + + + + ${project.artifactId} + + + org.apache.maven.plugins + maven-assembly-plugin + + + package + + single + + + + + + Main + + + + + jar-with-dependencies + + + + + + + + \ No newline at end of file diff --git a/tests/clients/java/src/main/java/Main.java b/tests/clients/java/src/main/java/Main.java new file mode 100644 index 0000000..7bd3a90 --- /dev/null +++ b/tests/clients/java/src/main/java/Main.java @@ -0,0 +1,87 @@ +import com.google.gson.Gson; +import com.maxmind.db.MaxMindDbConstructor; +import com.maxmind.db.MaxMindDbParameter; +import com.maxmind.db.Reader; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; + +import java.io.File; +import java.io.IOException; +import java.math.BigInteger; +import java.net.InetAddress; +import java.util.List; +import java.util.Map; + +public class Main { + @Option(name = "-db", usage = "Path to the MMDB file", required = true) + private String databasePath; + + @Option(name = "-ip", usage = "IP address to lookup", required = true) + private String ipAddress; + + public static void main(String[] args) throws Exception { + Main lookup = new Main(); + CmdLineParser parser = new CmdLineParser(lookup); + parser.parseArgument(args); + + lookup.run(); + } + + public void run() throws IOException { + File database = new File(databasePath); + Gson gson = new Gson(); + + try (Reader reader = new Reader(database)) { + InetAddress address = InetAddress.getByName(ipAddress); + + Record result = reader.get(address, Record.class); + String jsonResult = gson.toJson(result); + System.out.println(jsonResult); + } + } + + + public static class Record { + private Integer i32; + private Float f32; + private Double f64; + private Integer u16; + private Long u32; + private BigInteger u64; + private BigInteger u128; + private List array; + private Map map; + private byte[] bytes; + private String string; + private Boolean bool; + + @MaxMindDbConstructor + public Record( + @MaxMindDbParameter(name = "i32") Integer i32, + @MaxMindDbParameter(name = "f32") Float f32, + @MaxMindDbParameter(name = "f64") Double f64, + @MaxMindDbParameter(name = "u16") Integer u16, + @MaxMindDbParameter(name = "u32") Long u32, + @MaxMindDbParameter(name = "u64") BigInteger u64, + @MaxMindDbParameter(name = "u128") BigInteger u128, + @MaxMindDbParameter(name = "array") List array, + @MaxMindDbParameter(name = "map") Map map, + @MaxMindDbParameter(name = "bytes") byte[] bytes, + @MaxMindDbParameter(name = "string") String string, + @MaxMindDbParameter(name = "bool") Boolean bool + ) { + this.i32 = i32; + this.f32 = f32; + this.f64 = f64; + this.u16 = u16; + this.u32 = u32; + this.u64 = u64; + this.u128 = u128; + this.array = array; + this.map = map; + this.bytes = bytes; + this.string = string; + this.bool = bool; + } + } +} diff --git a/tests/record.py b/tests/record.py new file mode 100644 index 0000000..f7e041c --- /dev/null +++ b/tests/record.py @@ -0,0 +1,117 @@ +import random +from dataclasses import dataclass + +import numpy as np + +from mmdb_writer import ( + MmdbI32, + MmdbF32, + MmdbF64, + MmdbU16, + MmdbU32, + MmdbU64, + MmdbU128, + MmdbBaseType, +) + + +def random_str(length=10): + return "".join(random.choices("abc中文", k=length)) + + +def random_bytes(length=10): + return bytes(random.choices(range(256), k=length)) + + +def random_i32(): + return MmdbI32(random.randint(-(2**31), 0)) + + +def random_f32(): + return MmdbF32(np.float32(random.random())) + + +def random_f64(): + return MmdbF64(random.random() * 1e128) + + +def random_u16(): + return MmdbU16(random.randint(0, 2**16 - 1)) + + +def random_u32(): + return MmdbU32(random.randint(2**16, 2**32 - 1)) + + +def random_u64(): + return MmdbU64(random.randint(2**32, 2**64 - 1)) + + +def random_u128(): + return MmdbU128(random.randint(2**64, 2**128 - 1)) + + +def random_array(length=10, nested_type=False): + return [random_any(nested_type) for _ in range(length)] + + +def random_map(length=10, nested_type=False): + return {random_str(): random_any(nested_type) for _ in range(length)} + + +def random_bool(): + return random.choice([True, False]) + + +def random_any(nested_type=False): + return random.choice( + [ + random_i32, + random_f32, + random_f64, + random_u16, + random_u32, + random_u64, + random_u128, + random_bytes, + random_str, + random_bool, + *([random_array, random_map] if nested_type else []), + ] + )() + + +@dataclass +class Record(object): + i32: MmdbI32 + f32: MmdbF32 + f64: MmdbF64 + u16: MmdbU16 + u32: MmdbU32 + u64: MmdbU64 + u128: MmdbU128 + array: list + map: dict + bytes: bytes + string: str + bool: bool + + @staticmethod + def random(): + return Record( + i32=random_i32(), + f32=random_f32(), + f64=random_f64(), + u16=random_u16(), + u32=random_u32(), + u64=random_u64(), + u128=random_u128(), + array=random_array(5, True), + map=random_map(5, True), + bytes=random_bytes(), + string=random_str(), + bool=random_bool(), + ) + + def dict(self): + return self.__dict__ diff --git a/tests/test.py b/tests/test.py index d102965..bf4d1c8 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,21 +1,26 @@ # coding: utf-8 import logging import os.path +import random +import struct import unittest import maxminddb from netaddr import IPSet from mmdb_writer import MMDBWriter +from tests.record import Record -logging.basicConfig(format="[%(asctime)s: %(levelname)s] %(message)s", level=logging.INFO) -info1 = {'country': 'c1', 'isp': 'ISP1'} -info2 = {'country': 'c2', 'isp': 'ISP2'} +logging.basicConfig( + format="[%(asctime)s: %(levelname)s] %(message)s", level=logging.INFO +) +record1 = {"country": "c1", "isp": "ISP1"} +record2 = {"country": "c2", "isp": "ISP2"} class TestBuild(unittest.TestCase): def setUp(self) -> None: - self.filename = '_test.mmdb' + self.filename = "_test.mmdb" def tearDown(self) -> None: if os.path.exists(self.filename): @@ -23,12 +28,16 @@ def tearDown(self) -> None: def test_metadata(self): ip_version = 6 - database_type = 'test_database_type' - languages = ['en', 'ch'] - description = {'en': 'en test', 'ch': 'ch test'} - writer = MMDBWriter(ip_version=ip_version, database_type=database_type, - languages=languages, description=description, - ipv4_compatible=False) + database_type = "test_database_type" + languages = ["en", "ch"] + description = {"en": "en test", "ch": "ch test"} + writer = MMDBWriter( + ip_version=ip_version, + database_type=database_type, + languages=languages, + description=description, + ipv4_compatible=False, + ) writer.to_db_file(self.filename) for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE): m = maxminddb.open_database(self.filename, mode=mode) @@ -38,43 +47,47 @@ def test_metadata(self): self.assertEqual(description, m.metadata().description, mode) m.close() - def test_encode_type(self): - writer = MMDBWriter() - info = {'int': 1, 'float': 1.0 / 3, 'list': ['a', 'b', 'c'], 'dict': {'k': 'v'}, 'bytes': b'bytes', 'str': 'str'} - writer.insert_network(IPSet(['1.1.0.0/24']), info) - writer.to_db_file(self.filename) - for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE): - m = maxminddb.open_database(self.filename, mode=mode) - get = m.get('1.1.0.255') - self.assertEqual(len(info), len(get), mode) - self.assertEqual(info['int'], get['int'], mode) - self.assertTrue(abs(info['float'] - get['float']) < 1e-5, mode) - self.assertEqual(info['list'], get['list'], mode) - self.assertEqual(info['dict'], get['dict'], mode) - self.assertEqual(info['bytes'], get['bytes'], mode) - self.assertEqual(info['str'], get['str'], mode) - m.close() - def test_4in6(self): writer = MMDBWriter(ip_version=6, ipv4_compatible=True) - writer.insert_network(IPSet(['1.1.0.0/24']), info1) - writer.insert_network(IPSet(['fe80::/16']), info2) + writer.insert_network(IPSet(["1.1.0.0/24"]), record1) + writer.insert_network(IPSet(["fe80::/16"]), record2) writer.to_db_file(self.filename) for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE): m = maxminddb.open_database(self.filename, mode=mode) - self.assertEqual(info1, m.get('1.1.0.1'), mode) - self.assertEqual(info2, m.get('fe80::1'), mode) + self.assertEqual(record1, m.get("1.1.0.1"), mode) + self.assertEqual(record2, m.get("fe80::1"), mode) m.close() def test_insert_subnet(self): writer = MMDBWriter() - writer.insert_network(IPSet(['1.0.0.0/8']), info1) - writer.insert_network(IPSet(['1.10.10.0/24']), info2) + writer.insert_network(IPSet(["1.0.0.0/8"]), record1) + writer.insert_network(IPSet(["1.10.10.0/24"]), record2) writer.to_db_file(self.filename) for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE): m = maxminddb.open_database(self.filename, mode=mode) - self.assertEqual(info1, m.get('1.1.0.1'), mode) - self.assertEqual(info1, m.get('1.10.0.1'), mode) - self.assertEqual(info2, m.get('1.10.10.1'), mode) + self.assertEqual(record1, m.get("1.1.0.1"), mode) + self.assertEqual(record1, m.get("1.10.0.1"), mode) + self.assertEqual(record2, m.get("1.10.10.1"), mode) m.close() + def test_int_type(self): + value_range_map = { + "i32": (-(2 ** 31), 2 ** 31 - 1), + "u16": (0, 2 ** 16 - 1), + "u32": (0, 2 ** 32 - 1), + "u64": (0, 2 ** 64 - 1), + "u128": (0, 2 ** 128 - 1), + } + for int_type in ("i32", "u16", "u32", "u64", "u128"): + writer = MMDBWriter(int_type=int_type) + + (start, end) = value_range_map[int_type] + ok_value = random.randint(start, end) + bad_value1 = random.randint(end + 1, end + 2 ** 16) + bad_value2 = random.randint(start - 2 ** 16, start - 1) + writer.insert_network(IPSet(["1.0.0.0/8"]), {"value": ok_value}) + writer.to_db_file(self.filename) + for bad_value in (bad_value1, bad_value2): + writer.insert_network(IPSet(["1.0.0.0/8"]), {"value": bad_value}) + with self.assertRaises((ValueError, struct.error)): + writer.to_db_file(self.filename)