diff --git a/README.md b/README.md
index d111b7d..a5b88e6 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,58 @@ assert r == {'country': 'COUNTRY', 'isp': 'ISP'}
## Examples
see [csv_to_mmdb.py](./examples/csv_to_mmdb.py)
+
+## Using the Java Client
+
+### TLDR
+
+When generating an MMDB file for use with the Java client, you must specify the `int_type`:
+
+```python
+from mmdb_writer import MMDBWriter
+
+writer = MMDBWriter(int_type="i32")
+```
+
+Alternatively, you can explicitly specify the data type for each value, as described in the [Type Enforcement](#type-enforcement) section.
+
+### Underlying Principles
+
+In Java, when deserializing into a data class, numeric fields keep the original MMDB numeric types. The specific
+conversion relationships are as follows:
+
+| mmdb type | java type |
+|--------------|------------|
+| float (15) | Float |
+| double (3) | Double |
+| int32 (8) | Integer |
+| uint16 (5) | Integer |
+| uint32 (6) | Long |
+| uint64 (9) | BigInteger |
+| uint128 (10) | BigInteger |
+
+By default, the Python writer picks the MMDB integer type from the size of each `int` value: `int(1)` is stored
+as `uint16`, while `int(2**16+1)` is stored as `uint32`. The same field can therefore end up with different MMDB
+types for different networks, and deserialization in a Java client, whose field types are fixed, may fail.
+To avoid this, specify the `int_type` parameter when generating the MMDB file so that the numeric type is fixed.
+
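+For example, with the default `int_type="auto"`, the stored MMDB type of a field follows each value. A minimal
+sketch (reusing the constructor arguments from the bundled example; the `asn` field name is illustrative):
+
+```python
+from netaddr import IPSet
+
+from mmdb_writer import MMDBWriter
+
+writer = MMDBWriter(4, "Test.GeoIP", languages=["EN"], description="Test IP library")
+# "asn" is stored as uint16 for the first network and as uint32 for the
+# second, because the two values fall into different ranges.
+writer.insert_network(IPSet(["1.0.0.0/8"]), {"asn": 1})
+writer.insert_network(IPSet(["2.0.0.0/8"]), {"asn": 2**16 + 1})
+writer.to_db_file("fake_ip_library.mmdb")
+```
+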
+## Type Enforcement
+
+MMDB supports a variety of numeric types such as `int32`, `uint16`, `uint32`, `uint64`, `uint128` for integers,
+and `f32`, `f64` for floating points, while Python only has one integer type and one float type (actually `f64`).
+
+Therefore, when generating an MMDB file, you may need to specify the `int_type` parameter (and `float_type` for
+floating-point values) to control which MMDB numeric types are written. The behaviors for different `int_type` settings are:
+
+| int_type | Behavior |
+|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| auto (default) | Automatically selects the MMDB numeric type based on the value: `int32` for value < 0, `uint16` for 0 <= value < 2^16, `uint32` for 2^16 <= value < 2^32, `uint64` for 2^32 <= value < 2^64, `uint128` for value >= 2^64. |
+| i32 | Stores all integer types as `int32`. |
+| u16 | Stores all integer types as `uint16`. |
+| u32 | Stores all integer types as `uint32`. |
+| u64 | Stores all integer types as `uint64`. |
+| u128 | Stores all integer types as `uint128`. |
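+
+Besides the writer-level `int_type` (and `float_type`) options, individual values can be wrapped in the type classes
+exported by this module (`MmdbI32`, `MmdbU16`, `MmdbU32`, `MmdbU64`, `MmdbU128`, `MmdbF32`, `MmdbF64`) to pin the MMDB
+type per value. A minimal sketch (the field names are illustrative):
+
+```python
+from netaddr import IPSet
+
+from mmdb_writer import MMDBWriter, MmdbF32, MmdbI32
+
+writer = MMDBWriter(4, "Test.GeoIP", languages=["EN"], description="Test IP library")
+# Wrapped values are encoded with the requested MMDB type, regardless of
+# the writer-level int_type/float_type settings.
+writer.insert_network(
+    IPSet(["1.0.0.0/8"]), {"asn": MmdbI32(64512), "score": MmdbF32(0.5)}
+)
+writer.to_db_file("typed.mmdb")
+```
+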
## Reference:
diff --git a/examples/csv_to_mmdb.py b/examples/csv_to_mmdb.py
index b4a4f75..28fdcd3 100644
--- a/examples/csv_to_mmdb.py
+++ b/examples/csv_to_mmdb.py
@@ -8,25 +8,30 @@
def main():
- writer = MMDBWriter(4, 'Test.GeoIP', languages=['EN'], description="Test IP library")
+ writer = MMDBWriter(
+ 4, "Test.GeoIP", languages=["EN"], description="Test IP library"
+ )
data = defaultdict(list)
# merge cidr
- with open('fake_ip_info.csv', 'r') as f:
+ with open("fake_ip_info.csv", "r") as f:
reader = csv.DictReader(f)
for line in reader:
- data[(line['country'], line['isp'])].append(IPNetwork(f'{line["ip"]}/{line["prefixlen"]}'))
+ data[(line["country"], line["isp"])].append(
+ IPNetwork(f'{line["ip"]}/{line["prefixlen"]}')
+ )
for index, cidrs in data.items():
- writer.insert_network(IPSet(cidrs), {'country': index[0], 'isp': index[1]})
- writer.to_db_file('fake_ip_library.mmdb')
+ writer.insert_network(IPSet(cidrs), {"country": index[0], "isp": index[1]})
+ writer.to_db_file("fake_ip_library.mmdb")
def test_read():
import maxminddb
- m = maxminddb.open_database('fake_ip_library.mmdb')
- r = m.get('3.1.1.1')
+
+ m = maxminddb.open_database("fake_ip_library.mmdb")
+ r = m.get("3.1.1.1")
print(r)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/mmdb_writer.py b/mmdb_writer.py
index 8b57540..70949bb 100644
--- a/mmdb_writer.py
+++ b/mmdb_writer.py
@@ -1,20 +1,100 @@
# coding: utf-8
-__version__ = '0.1.1'
+__version__ = "0.2.1"
import logging
import math
import struct
import time
-from typing import Union
from decimal import Decimal
+from enum import IntEnum
+from typing import Union, List, Dict, Literal
from netaddr import IPSet, IPNetwork
-MMDBType = Union[dict, list, str, bytes, int, bool]
+
+class MmdbBaseType(object):
+ def __init__(self, value):
+ self.value = value
+
+
+# type hint
+class MmdbF32(MmdbBaseType):
+ def __init__(self, value: float):
+ super().__init__(value)
+
+
+class MmdbF64(MmdbBaseType):
+    def __init__(self, value: Union[float, Decimal]):
+ super().__init__(value)
+
+
+class MmdbI32(MmdbBaseType):
+ def __init__(self, value: int):
+ super().__init__(value)
+
+
+class MmdbU16(MmdbBaseType):
+ def __init__(self, value: int):
+ super().__init__(value)
+
+
+class MmdbU32(MmdbBaseType):
+ def __init__(self, value: int):
+ super().__init__(value)
+
+
+class MmdbU64(MmdbBaseType):
+ def __init__(self, value: int):
+ super().__init__(value)
+
+
+class MmdbU128(MmdbBaseType):
+ def __init__(self, value: int):
+ super().__init__(value)
+
+
+MMDBType = Union[
+ dict,
+ list,
+ str,
+ bytes,
+ int,
+ bool,
+ MmdbF32,
+ MmdbF64,
+ MmdbI32,
+ MmdbU16,
+ MmdbU32,
+ MmdbU64,
+ MmdbU128,
+]
logger = logging.getLogger(__name__)
-METADATA_MAGIC = b'\xab\xcd\xefMaxMind.com'
+METADATA_MAGIC = b"\xab\xcd\xefMaxMind.com"
+
+
+class MMDBTypeID(IntEnum):
+ POINTER = 1
+ STRING = 2
+ DOUBLE = 3
+ BYTES = 4
+ UINT16 = 5
+ UINT32 = 6
+ MAP = 7
+ INT32 = 8
+ UINT64 = 9
+ UINT128 = 10
+ ARRAY = 11
+ DATA_CACHE = 12
+ END_MARKER = 13
+ BOOLEAN = 14
+ FLOAT = 15
+
+
+UINT16_MAX = 0xFFFF
+UINT32_MAX = 0xFFFFFFFF
+UINT64_MAX = 0xFFFFFFFFFFFFFFFF
class SearchTreeNode(object):
@@ -53,55 +133,97 @@ def __repr__(self):
__str__ = __repr__
+IntType = Union[
+    Literal["auto", "u16", "u32", "u64", "u128", "i32"],
+    MmdbU16,
+    MmdbU32,
+    MmdbU64,
+    MmdbU128,
+    MmdbI32,
+]
+FloatType = Union[Literal["f32", "f64"], MmdbF32, MmdbF64]
+
+
class Encoder(object):
+ def __init__(
+ self, cache=True, int_type: IntType = "auto", float_type: FloatType = "f64"
+ ):
+ self.cache = cache
+ self.int_type = int_type
+ self.float_type = float_type
- def __init__(self, cache=True):
self.data_cache = {}
self.data_list = []
self.data_pointer = 0
-
- self.cache = cache
+ self._python_type_id = {
+            # NOTE: plain float is resolved in python_type_id() so that float_type is honored
+ bool: MMDBTypeID.BOOLEAN,
+ list: MMDBTypeID.ARRAY,
+ dict: MMDBTypeID.MAP,
+ bytes: MMDBTypeID.BYTES,
+ str: MMDBTypeID.STRING,
+ MmdbF32: MMDBTypeID.FLOAT,
+ MmdbF64: MMDBTypeID.DOUBLE,
+ MmdbI32: MMDBTypeID.INT32,
+ MmdbU16: MMDBTypeID.UINT16,
+ MmdbU32: MMDBTypeID.UINT32,
+ MmdbU64: MMDBTypeID.UINT64,
+ MmdbU128: MMDBTypeID.UINT128,
+ }
def _encode_pointer(self, value):
pointer = value
if pointer >= 134744064:
- res = struct.pack('>BI', 0x38, pointer)
+ res = struct.pack(">BI", 0x38, pointer)
elif pointer >= 526336:
pointer -= 526336
- res = struct.pack('>BBBB', 0x30 + ((pointer >> 24) & 0x07),
- (pointer >> 16) & 0xff, (pointer >> 8) & 0xff,
- pointer & 0xff)
+ res = struct.pack(
+ ">BBBB",
+ 0x30 + ((pointer >> 24) & 0x07),
+ (pointer >> 16) & 0xFF,
+ (pointer >> 8) & 0xFF,
+ pointer & 0xFF,
+ )
elif pointer >= 2048:
pointer -= 2048
- res = struct.pack('>BBB', 0x28 + ((pointer >> 16) & 0x07),
- (pointer >> 8) & 0xff, pointer & 0xff)
+ res = struct.pack(
+ ">BBB",
+ 0x28 + ((pointer >> 16) & 0x07),
+ (pointer >> 8) & 0xFF,
+ pointer & 0xFF,
+ )
else:
- res = struct.pack('>BB', 0x20 + ((pointer >> 8) & 0x07),
- pointer & 0xff)
+ res = struct.pack(">BB", 0x20 + ((pointer >> 8) & 0x07), pointer & 0xFF)
return res
def _encode_utf8_string(self, value):
- encoded_value = value.encode('utf-8')
- res = self._make_header(2, len(encoded_value))
+ encoded_value = value.encode("utf-8")
+ res = self._make_header(MMDBTypeID.STRING, len(encoded_value))
res += encoded_value
return res
def _encode_bytes(self, value):
- return self._make_header(4, len(value)) + value
+ return self._make_header(MMDBTypeID.BYTES, len(value)) + value
def _encode_uint(self, type_id, max_len):
+ value_max = 2 ** (max_len * 8)
+
def _encode_unsigned_value(value):
- res = b''
+ if value < 0 or value >= value_max:
+ raise ValueError(
+ f"encode uint{max_len * 8} fail: {value} not in range(0, {value_max})"
+ )
+ res = b""
while value != 0 and len(res) < max_len:
- res = struct.pack('>B', value & 0xff) + res
+ res = struct.pack(">B", value & 0xFF) + res
value = value >> 8
return self._make_header(type_id, len(res)) + res
return _encode_unsigned_value
def _encode_map(self, value):
- res = self._make_header(7, len(value))
+ res = self._make_header(MMDBTypeID.MAP, len(value))
for k, v in list(value.items()):
# Keys are always stored by value.
res += self.encode(k)
@@ -109,13 +231,13 @@ def _encode_map(self, value):
return res
def _encode_array(self, value):
- res = self._make_header(11, len(value))
+ res = self._make_header(MMDBTypeID.ARRAY, len(value))
for k in value:
res += self.encode(k)
return res
def _encode_boolean(self, value):
- return self._make_header(14, 1 if value else 0)
+ return self._make_header(MMDBTypeID.BOOLEAN, 1 if value else 0)
def _encode_pack_type(self, type_id, fmt):
def pack_type(value):
@@ -124,97 +246,109 @@ def pack_type(value):
return pack_type
- _type_decoder = None
+ _type_encoder = None
@property
- def type_decoder(self):
- if self._type_decoder is None:
- self._type_decoder = {
- 1: self._encode_pointer,
- 2: self._encode_utf8_string,
- 3: self._encode_pack_type(3, '>d'), # double,
- 4: self._encode_bytes,
- 5: self._encode_uint(5, 2), # uint16
- 6: self._encode_uint(6, 4), # uint32
- 7: self._encode_map,
- 8: self._encode_pack_type(8, '>i'), # int32
- 9: self._encode_uint(9, 8), # uint64
- 10: self._encode_uint(10, 16), # uint128
- 11: self._encode_array,
- 14: self._encode_boolean,
- 15: self._encode_pack_type(15, '>f'), # float,
+ def type_encoder(self):
+ if self._type_encoder is None:
+ self._type_encoder = {
+ MMDBTypeID.POINTER: self._encode_pointer,
+ MMDBTypeID.STRING: self._encode_utf8_string,
+ MMDBTypeID.DOUBLE: self._encode_pack_type(MMDBTypeID.DOUBLE, ">d"),
+ MMDBTypeID.BYTES: self._encode_bytes,
+ MMDBTypeID.UINT16: self._encode_uint(MMDBTypeID.UINT16, 2),
+ MMDBTypeID.UINT32: self._encode_uint(MMDBTypeID.UINT32, 4),
+ MMDBTypeID.MAP: self._encode_map,
+ MMDBTypeID.INT32: self._encode_pack_type(MMDBTypeID.INT32, ">i"),
+ MMDBTypeID.UINT64: self._encode_uint(MMDBTypeID.UINT64, 8),
+ MMDBTypeID.UINT128: self._encode_uint(MMDBTypeID.UINT128, 16),
+ MMDBTypeID.ARRAY: self._encode_array,
+ MMDBTypeID.BOOLEAN: self._encode_boolean,
+ MMDBTypeID.FLOAT: self._encode_pack_type(MMDBTypeID.FLOAT, ">f"),
}
- return self._type_decoder
+ return self._type_encoder
def _make_header(self, type_id, length):
if length >= 16843036:
- raise Exception('length >= 16843036')
+ raise Exception("length >= 16843036")
elif length >= 65821:
five_bits = 31
length -= 65821
- b3 = length & 0xff
- b2 = (length >> 8) & 0xff
- b1 = (length >> 16) & 0xff
- additional_length_bytes = struct.pack('>BBB', b1, b2, b3)
+ b3 = length & 0xFF
+ b2 = (length >> 8) & 0xFF
+ b1 = (length >> 16) & 0xFF
+ additional_length_bytes = struct.pack(">BBB", b1, b2, b3)
elif length >= 285:
five_bits = 30
length -= 285
- b2 = length & 0xff
- b1 = (length >> 8) & 0xff
- additional_length_bytes = struct.pack('>BB', b1, b2)
+ b2 = length & 0xFF
+ b1 = (length >> 8) & 0xFF
+ additional_length_bytes = struct.pack(">BB", b1, b2)
elif length >= 29:
five_bits = 29
length -= 29
- additional_length_bytes = struct.pack('>B', length & 0xff)
+ additional_length_bytes = struct.pack(">B", length & 0xFF)
else:
five_bits = length
- additional_length_bytes = b''
+ additional_length_bytes = b""
if type_id <= 7:
- res = struct.pack('>B', (type_id << 5) + five_bits)
+ res = struct.pack(">B", (type_id << 5) + five_bits)
else:
- res = struct.pack('>BB', five_bits, type_id - 7)
+ res = struct.pack(">BB", five_bits, type_id - 7)
return res + additional_length_bytes
- _python_type_id = {
- float: 15,
- bool: 14,
- list: 11,
- dict: 7,
- bytes: 4,
- str: 2
- }
-
def python_type_id(self, value):
value_type = type(value)
type_id = self._python_type_id.get(value_type)
if type_id:
return type_id
if value_type is int:
- if value > 0xffffffffffffffff:
- return 10
- elif value > 0xffffffff:
- return 9
- elif value > 0xffff:
- return 6
- elif value < 0:
- return 8
+ if self.int_type == "auto":
+ if value > UINT64_MAX:
+ return MMDBTypeID.UINT128
+ elif value > UINT32_MAX:
+ return MMDBTypeID.UINT64
+ elif value > UINT16_MAX:
+ return MMDBTypeID.UINT32
+ elif value < 0:
+ return MMDBTypeID.INT32
+ else:
+ return MMDBTypeID.UINT16
+ elif self.int_type in ("u16", MmdbU16):
+ return MMDBTypeID.UINT16
+ elif self.int_type in ("u32", MmdbU32):
+ return MMDBTypeID.UINT32
+ elif self.int_type in ("u64", MmdbU64):
+ return MMDBTypeID.UINT64
+ elif self.int_type in ("u128", MmdbU128):
+ return MMDBTypeID.UINT128
+ elif self.int_type in ("i32", MmdbI32):
+ return MMDBTypeID.INT32
+ elif value_type is float:
+ if self.float_type in ("f32", MmdbF32):
+ return MMDBTypeID.FLOAT
else:
- return 5
- if value_type is Decimal:
- return 3
+ return MMDBTypeID.DOUBLE
+ elif value_type is Decimal:
+ return MMDBTypeID.DOUBLE
raise TypeError("unknown type {value_type}".format(value_type=value_type))
def encode_meta(self, meta):
- res = self._make_header(7, len(meta))
- meta_type = {'node_count': 6, 'record_size': 5, 'ip_version': 5,
- 'binary_format_major_version': 5, 'binary_format_minor_version': 5,
- 'build_epoch': 9}
+ res = self._make_header(MMDBTypeID.MAP, len(meta))
+ meta_type = {
+ "node_count": 6,
+ "record_size": 5,
+ "ip_version": 5,
+ "binary_format_major_version": 5,
+ "binary_format_minor_version": 5,
+ "build_epoch": 9,
+ }
for k, v in list(meta.items()):
# Keys are always stored by value.
res += self.encode(k)
@@ -232,9 +366,12 @@ def encode(self, value, type_id=None):
type_id = self.python_type_id(value)
try:
- encoder = self.type_decoder[type_id]
+ encoder = self.type_encoder[type_id]
except KeyError:
raise ValueError("unknown type_id={type_id}".format(type_id=type_id))
+
+ if isinstance(value, MmdbBaseType):
+ value = value.value
res = encoder(value)
if self.cache:
@@ -256,7 +393,13 @@ def encode(self, value, type_id=None):
class TreeWriter(object):
encoder_cls = Encoder
- def __init__(self, tree, meta):
+ def __init__(
+ self,
+ tree: "SearchTreeNode",
+ meta: dict,
+ int_type: IntType = "auto",
+ float_type: FloatType = "f64",
+ ):
self._node_idx = {}
self._leaf_offset = {}
self._node_list = []
@@ -266,7 +409,9 @@ def __init__(self, tree, meta):
self.tree = tree
self.meta = meta
- self.encoder = self.encoder_cls(cache=True)
+ self.encoder = self.encoder_cls(
+ cache=True, int_type=int_type, float_type=float_type
+ )
@property
def _data_list(self):
@@ -280,7 +425,7 @@ def _build_meta(self):
return {
"node_count": self._node_counter,
"record_size": self.record_size,
- **self.meta
+ **self.meta,
}
def _adjust_record_size(self):
@@ -297,7 +442,7 @@ def _adjust_record_size(self):
elif bit_count <= 32:
self.record_size = 32
else:
- raise Exception('record_size > 32')
+ raise Exception("record_size > 32")
self.data_offset = self.record_size * 2 / 8 * self._node_counter
@@ -335,40 +480,39 @@ def _cal_node_bytes(self, node) -> bytes:
right_idx = self._calc_record_idx(node.right)
if self.record_size == 24:
- b1 = (left_idx >> 16) & 0xff
- b2 = (left_idx >> 8) & 0xff
- b3 = left_idx & 0xff
- b4 = (right_idx >> 16) & 0xff
- b5 = (right_idx >> 8) & 0xff
- b6 = right_idx & 0xff
- return struct.pack('>BBBBBB', b1, b2, b3, b4, b5, b6)
+ b1 = (left_idx >> 16) & 0xFF
+ b2 = (left_idx >> 8) & 0xFF
+ b3 = left_idx & 0xFF
+ b4 = (right_idx >> 16) & 0xFF
+ b5 = (right_idx >> 8) & 0xFF
+ b6 = right_idx & 0xFF
+ return struct.pack(">BBBBBB", b1, b2, b3, b4, b5, b6)
elif self.record_size == 28:
- b1 = (left_idx >> 16) & 0xff
- b2 = (left_idx >> 8) & 0xff
- b3 = left_idx & 0xff
- b4 = ((left_idx >> 24) & 0xf) * 16 + \
- ((right_idx >> 24) & 0xf)
- b5 = (right_idx >> 16) & 0xff
- b6 = (right_idx >> 8) & 0xff
- b7 = right_idx & 0xff
- return struct.pack('>BBBBBBB', b1, b2, b3, b4, b5, b6, b7)
+ b1 = (left_idx >> 16) & 0xFF
+ b2 = (left_idx >> 8) & 0xFF
+ b3 = left_idx & 0xFF
+ b4 = ((left_idx >> 24) & 0xF) * 16 + ((right_idx >> 24) & 0xF)
+ b5 = (right_idx >> 16) & 0xFF
+ b6 = (right_idx >> 8) & 0xFF
+ b7 = right_idx & 0xFF
+ return struct.pack(">BBBBBBB", b1, b2, b3, b4, b5, b6, b7)
elif self.record_size == 32:
- return struct.pack('>II', left_idx, right_idx)
+ return struct.pack(">II", left_idx, right_idx)
else:
- raise Exception('self.record_size > 32')
+ raise Exception("self.record_size > 32")
def write(self, fname):
self._enumerate_nodes(self.tree)
self._adjust_record_size()
- with open(fname, 'wb') as f:
+ with open(fname, "wb") as f:
for node in self._node_list:
f.write(self._cal_node_bytes(node))
- f.write(b'\x00' * 16)
+ f.write(b"\x00" * 16)
for element in self._data_list:
f.write(element)
@@ -378,14 +522,35 @@ def write(self, fname):
def bits_rstrip(n, length=None, keep=0):
- return map(int, bin(n)[2:].rjust(length, '0')[:keep])
+ return map(int, bin(n)[2:].rjust(length, "0")[:keep])
class MMDBWriter(object):
-
- def __init__(self, ip_version=4, database_type='GeoIP',
- languages=None, description='GeoIP db',
- ipv4_compatible=False):
+ def __init__(
+ self,
+ ip_version=4,
+ database_type="GeoIP",
+ languages: List[str] = None,
+        description: Union[Dict[str, str], str] = "GeoIP db",
+ ipv4_compatible=False,
+ int_type: IntType = "auto",
+ float_type: FloatType = "f64",
+ ):
+ """
+ Args:
+ ip_version (int, optional): The IP version of the database. Defaults to 4.
+ database_type (str, optional): The type of the database. Defaults to "GeoIP".
+            languages (List[str], optional): A list of language codes (e.g. ["en"]). Defaults to None.
+ description (Union[Dict[str, str], str], optional): A description of the database for every language.
+ ipv4_compatible (bool, optional): Whether the database is compatible with IPv4. Defaults to False.
+ int_type (Union[str, MmdbU16, MmdbU32, MmdbU64, MmdbU128, MmdbI32], optional): The type of integer to use. Defaults to "auto".
+ float_type (Union[str, MmdbF32, MmdbF64], optional): The type of float to use. Defaults to "f64".
+
+ Note:
+ If you want to store an IPv4 address in an IPv6 database, you should set ipv4_compatible=True.
+
+ If you want to use a specific integer type, you can set int_type to "u16", "u32", "u64", "u128", or "i32".
+ """
self.tree = SearchTreeNode()
self.ipv4_compatible = ipv4_compatible
@@ -401,43 +566,87 @@ def __init__(self, ip_version=4, database_type='GeoIP',
self._bit_length = 128 if ip_version == 6 else 32
if ip_version not in [4, 6]:
- raise ValueError("ip_version should be 4 or 6, {} is incorrect".format(ip_version))
+ raise ValueError(
+ "ip_version should be 4 or 6, {} is incorrect".format(ip_version)
+ )
if ip_version == 4 and ipv4_compatible:
raise ValueError("ipv4_compatible=True can set when ip_version=6")
if not self.binary_format_major_version:
- raise ValueError("major_version can't be empty or 0: {}".format(self.binary_format_major_version))
+ raise ValueError(
+ "major_version can't be empty or 0: {}".format(
+ self.binary_format_major_version
+ )
+ )
if isinstance(description, str):
self.description = {i: description for i in languages}
for i in languages:
if i not in self.description:
raise ValueError("language {} must have description!")
- def insert_network(self, network: IPSet, content: MMDBType):
+ self.int_type = int_type
+ self.float_type = float_type
+
+ def insert_network(
+ self, network: IPSet, content: MMDBType, overwrite=True, python_type_id_map=None
+ ):
+ """
+ Inserts a network into the MaxMind database.
+
+ Args:
+ network (IPSet): The network to be inserted. It should be an instance of netaddr.IPSet.
+ content (MMDBType): The content associated with the network. It can be a dictionary, list, string, bytes, integer, or boolean.
+ overwrite (bool, optional): If True, existing network data will be overwritten. Defaults to True.
+            python_type_id_map (dict, optional): Custom mapping from Python types to MMDB type IDs, overriding the encoder defaults for this insert. Defaults to None.
+
+ Raises:
+ ValueError: If the network is not an instance of netaddr.IPSet.
+ ValueError: If an IPv6 address is inserted into an IPv4-only database.
+ ValueError: If an IPv4 address is inserted into an IPv6 database without setting ipv4_compatible=True.
+
+ Note:
+ This method modifies the internal tree structure of the MMDBWriter instance.
+ """
leaf = SearchTreeLeaf(content)
if not isinstance(network, IPSet):
raise ValueError("network type should be netaddr.IPSet.")
network = network.iter_cidrs()
for cidr in network:
if self.ip_version == 4 and cidr.version == 6:
- raise ValueError('You inserted a IPv6 address {} '
- 'to an IPv4-only database.'.format(cidr))
+                raise ValueError(
+                    "You inserted an IPv6 address {} "
+                    "into an IPv4-only database.".format(cidr)
+                )
if self.ip_version == 6 and cidr.version == 4:
if not self.ipv4_compatible:
- raise ValueError("You inserted a IPv4 address {} to an IPv6 database."
- "Please use ipv4_compatible=True option store "
- "IPv4 address in IPv6 database as ::/96 format".format(cidr))
+                    raise ValueError(
+                        "You inserted an IPv4 address {} into an IPv6 database. "
+                        "Please use the ipv4_compatible=True option to store "
+                        "IPv4 addresses in an IPv6 database in ::/96 format.".format(cidr)
+                    )
cidr = cidr.ipv6(True)
node = self.tree
bits = list(bits_rstrip(cidr.value, self._bit_length, cidr.prefixlen))
current_node = node
supernet_leaf = None # Tracks whether we are inserting into a subnet
- for (index, ip_bit) in enumerate(bits[:-1]):
+ for index, ip_bit in enumerate(bits[:-1]):
previous_node = current_node
current_node = previous_node.get_or_create(ip_bit)
if isinstance(current_node, SearchTreeLeaf):
- current_cidr = IPNetwork((int(''.join(map(str, bits[:index + 1])).ljust(self._bit_length, '0'), 2), index + 1))
- logger.info(f"Inserting {cidr} ({content}) into subnet of {current_cidr} ({current_node.value})")
+ current_cidr = IPNetwork(
+ (
+ int(
+ "".join(map(str, bits[: index + 1])).ljust(
+ self._bit_length, "0"
+ ),
+ 2,
+ ),
+ index + 1,
+ )
+ )
+ logger.info(
+ f"Inserting {cidr} ({content}) into subnet of {current_cidr} ({current_node.value})"
+ )
supernet_leaf = current_node
current_node = SearchTreeNode()
previous_node[ip_bit] = current_node
@@ -449,7 +658,9 @@ def insert_network(self, network: IPSet, content: MMDBType):
current_node[bits[-1]] = leaf
def to_db_file(self, filename: str):
- return TreeWriter(self.tree, self._build_meta()).write(filename)
+ return TreeWriter(
+ self.tree, self._build_meta(), self.int_type, self.float_type
+ ).write(filename)
def _build_meta(self):
return {
diff --git a/setup.py b/setup.py
index c881b0f..a771eb0 100644
--- a/setup.py
+++ b/setup.py
@@ -10,10 +10,9 @@ def get_version(file):
return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1)
-version = get_version('mmdb_writer.py')
-f = open(os.path.join(os.path.dirname(__file__), 'README.md'))
-readme = f.read()
-f.close()
+version = get_version("mmdb_writer.py")
+with open(os.path.join(os.path.dirname(__file__), "README.md")) as f:
+ readme = f.read()
setup(
name="mmdb_writer",
@@ -21,22 +20,19 @@ def get_version(file):
description="Make `mmdb` format ip library file which can be read by maxmind official language reader",
long_description=readme,
long_description_content_type="text/markdown",
- author='VimT',
- author_email='me@vimt.me',
- url='https://github.com/VimT/MaxMind-DB-Writer-python',
- py_modules=['mmdb_writer'],
- platforms=['any'],
+ author="VimT",
+ author_email="me@vimt.me",
+ url="https://github.com/VimT/MaxMind-DB-Writer-python",
+ py_modules=["mmdb_writer"],
+ platforms=["any"],
zip_safe=False,
python_requires=">=3.6",
- install_requires=['netaddr>=0.7'],
- tests_require=['maxminddb>=1.5'],
+ install_requires=["netaddr>=0.7"],
+ tests_require=["maxminddb>=1.5", "numpy"],
classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Programming Language :: Python :: 3.8',
- 'Programming Language :: Python :: Implementation :: CPython',
+ "Development Status :: 5 - Production/Stable",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: Implementation :: CPython",
],
)
diff --git a/tests/clients.py b/tests/clients.py
new file mode 100644
index 0000000..e796570
--- /dev/null
+++ b/tests/clients.py
@@ -0,0 +1,131 @@
+import base64
+import json
+import logging
+import os
+import subprocess
+import unittest
+from pathlib import Path
+
+import maxminddb
+from netaddr.ip.sets import IPSet
+
+from mmdb_writer import MMDBWriter, MmdbBaseType, MmdbF32
+from tests.record import Record
+
+logging.basicConfig(
+ format="[%(asctime)s: %(levelname)s] %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+BASE_DIR = Path(__file__).parent.absolute()
+
+
+def run(command: list):
+ print(f"Running command: {command}")
+ result = subprocess.run(command, check=True, stdout=subprocess.PIPE)
+ return result.stdout
+
+
+class TestClients(unittest.TestCase):
+ def setUp(self) -> None:
+ self.filepath = Path("_test.mmdb").absolute()
+ self.filepath.unlink(True)
+ self.ip = "1.1.1.1"
+ self.origin_data = Record.random()
+ self.generate_mmdb()
+ self.maxDiff = None
+
+ def tearDown(self) -> None:
+ self.filepath.unlink(True)
+
+ def generate_mmdb(self):
+ ip_version = 4
+ database_type = "test_client"
+ languages = ["en"]
+ description = {"en": "for testing purposes only"}
+ writer = MMDBWriter(
+ ip_version=ip_version,
+ database_type=database_type,
+ languages=languages,
+ description=description,
+ ipv4_compatible=False,
+ )
+
+ writer.insert_network(IPSet(["1.0.0.0/8"]), self.origin_data.dict())
+
+ # insert other useless record
+ for i in range(2, 250):
+ info = Record.random()
+ writer.insert_network(IPSet([f"{i}.0.0.0/8"]), info.dict())
+
+ writer.to_db_file(str(self.filepath))
+
+ @staticmethod
+ def convert_bytes(d, bytes_convert, f32_convert=lambda x: float(str(x))):
+ def inner(d):
+ if isinstance(d, bytes):
+ return bytes_convert(d)
+ elif isinstance(d, dict):
+ return {k: inner(v) for k, v in d.items()}
+ elif isinstance(d, list):
+ return [inner(i) for i in d]
+ elif isinstance(d, MmdbF32):
+ return f32_convert(d.value)
+ elif isinstance(d, MmdbBaseType):
+ return d.value
+ else:
+ return d
+
+ return inner(d)
+
+ def test_python(self):
+ for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE):
+ m = maxminddb.open_database(self.filepath, mode=mode)
+ python_data = m.get(self.ip)
+ should_data = self.origin_data.dict()
+ should_data = self.convert_bytes(
+ should_data, lambda x: bytearray(x), lambda x: float(x)
+ )
+ self.assertDictEqual(should_data, python_data)
+ m.close()
+
+ def test_java(self):
+ java_dir = BASE_DIR / "clients" / "java"
+ self.assertTrue(java_dir.exists())
+ os.chdir(java_dir)
+ run(["mvn", "clean", "package"])
+ java_data_str = run(
+ [
+ "java",
+ "-jar",
+ "target/mmdb-test-jar-with-dependencies.jar",
+ "-db",
+ str(self.filepath),
+ "-ip",
+ self.ip,
+ ]
+ )
+ java_data = json.loads(java_data_str)
+ should_data = self.origin_data.dict()
+
+ # java bytes marshal as i8 list
+ should_data = self.convert_bytes(
+ should_data, lambda x: [i if i <= 127 else i - 256 for i in x]
+ )
+ self.assertDictEqual(should_data, java_data)
+
+ def test_go(self):
+ go_dir = BASE_DIR / "clients" / "go"
+ self.assertTrue(go_dir.exists())
+ os.chdir(go_dir)
+ go_data_str = run(
+ ["go", "run", "main.go", "-db", str(self.filepath), "-ip", self.ip]
+ )
+ go_data = json.loads(go_data_str)
+
+ should_data = self.origin_data.dict()
+ # go bytes marshal as base64 str
+ should_data = self.convert_bytes(
+ should_data, lambda x: base64.b64encode(x).decode()
+ )
+ self.assertDictEqual(should_data, go_data)
diff --git a/tests/clients/go/go.mod b/tests/clients/go/go.mod
new file mode 100644
index 0000000..4856216
--- /dev/null
+++ b/tests/clients/go/go.mod
@@ -0,0 +1,7 @@
+module mmdb-test
+
+go 1.22
+
+require github.com/oschwald/maxminddb-golang v1.12.0
+
+require golang.org/x/sys v0.10.0 // indirect
diff --git a/tests/clients/go/main.go b/tests/clients/go/main.go
new file mode 100644
index 0000000..984aa67
--- /dev/null
+++ b/tests/clients/go/main.go
@@ -0,0 +1,59 @@
+package main
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "github.com/oschwald/maxminddb-golang"
+ "log"
+ "math/big"
+ "net"
+ "os"
+)
+
+var (
+ db = flag.String("db", "", "Path to the MaxMind DB file")
+ ip = flag.String("ip", "", "IP address to look up")
+)
+
+type Record struct {
+ I32 int `json:"i32" maxminddb:"i32"`
+ F32 float32 `json:"f32" maxminddb:"f32"`
+ F64 float64 `json:"f64" maxminddb:"f64"`
+ U16 uint16 `json:"u16" maxminddb:"u16"`
+ U32 uint32 `json:"u32" maxminddb:"u32"`
+ U64 uint64 `json:"u64" maxminddb:"u64"`
+ U128 *big.Int `json:"u128" maxminddb:"u128"`
+ Array []any `json:"array" maxminddb:"array"`
+ Map map[string]any `json:"map" maxminddb:"map"`
+ Bytes []byte `json:"bytes" maxminddb:"bytes"`
+ String string `json:"string" maxminddb:"string"`
+ Bool bool `json:"bool" maxminddb:"bool"`
+}
+
+func main() {
+ flag.Parse()
+ if *db == "" || *ip == "" {
+ flag.PrintDefaults()
+ os.Exit(1)
+ }
+ db, err := maxminddb.Open(*db)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer db.Close()
+
+ ip := net.ParseIP(*ip)
+
+ var record Record
+
+ err = db.Lookup(ip, &record)
+ if err != nil {
+ log.Panic(err)
+ }
+ data, err := json.Marshal(record)
+ if err != nil {
+ log.Panic(err)
+ }
+ fmt.Println(string(data))
+}
diff --git a/tests/clients/java/.gitignore b/tests/clients/java/.gitignore
new file mode 100644
index 0000000..5ff6309
--- /dev/null
+++ b/tests/clients/java/.gitignore
@@ -0,0 +1,38 @@
+target/
+!.mvn/wrapper/maven-wrapper.jar
+!**/src/main/**/target/
+!**/src/test/**/target/
+
+### IntelliJ IDEA ###
+.idea/modules.xml
+.idea/jarRepositories.xml
+.idea/compiler.xml
+.idea/libraries/
+*.iws
+*.iml
+*.ipr
+
+### Eclipse ###
+.apt_generated
+.classpath
+.factorypath
+.project
+.settings
+.springBeans
+.sts4-cache
+
+### NetBeans ###
+/nbproject/private/
+/nbbuild/
+/dist/
+/nbdist/
+/.nb-gradle/
+build/
+!**/src/main/**/build/
+!**/src/test/**/build/
+
+### VS Code ###
+.vscode/
+
+### Mac OS ###
+.DS_Store
\ No newline at end of file
diff --git a/tests/clients/java/pom.xml b/tests/clients/java/pom.xml
new file mode 100644
index 0000000..b2a0ac0
--- /dev/null
+++ b/tests/clients/java/pom.xml
@@ -0,0 +1,62 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>me.vime</groupId>
+    <artifactId>mmdb-test</artifactId>
+    <version>1.0-SNAPSHOT</version>
+
+    <properties>
+        <maven.compiler.source>22</maven.compiler.source>
+        <maven.compiler.target>22</maven.compiler.target>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.maxmind.db</groupId>
+            <artifactId>maxmind-db</artifactId>
+            <version>3.1.0</version>
+        </dependency>
+        <dependency>
+            <groupId>args4j</groupId>
+            <artifactId>args4j</artifactId>
+            <version>2.33</version>
+        </dependency>
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>2.10.1</version>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <finalName>${project.artifactId}</finalName>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <archive>
+                        <manifest>
+                            <mainClass>Main</mainClass>
+                        </manifest>
+                    </archive>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/tests/clients/java/src/main/java/Main.java b/tests/clients/java/src/main/java/Main.java
new file mode 100644
index 0000000..7bd3a90
--- /dev/null
+++ b/tests/clients/java/src/main/java/Main.java
@@ -0,0 +1,87 @@
+import com.google.gson.Gson;
+import com.maxmind.db.MaxMindDbConstructor;
+import com.maxmind.db.MaxMindDbParameter;
+import com.maxmind.db.Reader;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.net.InetAddress;
+import java.util.List;
+import java.util.Map;
+
+public class Main {
+ @Option(name = "-db", usage = "Path to the MMDB file", required = true)
+ private String databasePath;
+
+ @Option(name = "-ip", usage = "IP address to lookup", required = true)
+ private String ipAddress;
+
+ public static void main(String[] args) throws Exception {
+ Main lookup = new Main();
+ CmdLineParser parser = new CmdLineParser(lookup);
+ parser.parseArgument(args);
+
+ lookup.run();
+ }
+
+ public void run() throws IOException {
+ File database = new File(databasePath);
+ Gson gson = new Gson();
+
+ try (Reader reader = new Reader(database)) {
+ InetAddress address = InetAddress.getByName(ipAddress);
+
+ Record result = reader.get(address, Record.class);
+ String jsonResult = gson.toJson(result);
+ System.out.println(jsonResult);
+ }
+ }
+
+
+ public static class Record {
+ private Integer i32;
+ private Float f32;
+ private Double f64;
+ private Integer u16;
+ private Long u32;
+ private BigInteger u64;
+ private BigInteger u128;
+ private List