|
78 | 78 | "base_decode", |
79 | 79 | "gen_id", |
80 | 80 | "timeti", |
| 81 | + "SnowFlake", |
81 | 82 | ] |
82 | 83 |
|
83 | 84 |
|
@@ -1892,6 +1893,160 @@ def timeti( |
1892 | 1893 | return int(1 / (result / number)) |
1893 | 1894 |
|
1894 | 1895 |
|
class SnowFlake:
    def __init__(
        self,
        machine_id: int = 1,
        worker_id: int = 1,
        wait_for_next_ms=True,
        start_epoch_ms: int = -1,
        *,
        max_bits=64,
        sign_bit=1,
        machine_id_bit=5,  # between 0 and 31
        worker_id_bit=5,  # between 0 and 31
        seq_bit=12,  # between 0 and 4095
        start_date="2025-01-01",
    ):
        r"""Generate unique IDs using Twitter's Snowflake algorithm.

        With the default settings the ID is composed of (high bits to low bits):
            - 1 sign bit (always 0)
            - 41 bits timestamp in milliseconds since a custom epoch
            - 5 bits machine ID
            - 5 bits worker ID
            - 12 bits sequence number
        If you need to ensure thread safety, please lock it yourself.

        Args:
            machine_id (int): the machine_id of the SnowFlake object
            worker_id (int): the worker_id of the SnowFlake object
            wait_for_next_ms (bool, optional): whether to wait for next millisecond if sequence overflows. If False, get_id raises RuntimeError instead. Defaults to True.
            start_epoch_ms (int, optional): the start epoch in milliseconds. A negative value means "derive it from start_date". Defaults to -1.
            *
            max_bits (int, optional): the maximum bits of the ID. Defaults to 64.
            sign_bit (int, optional): the sign bit of the ID. Defaults to 1.
            machine_id_bit (int, optional): the machine_id bit of the ID. Defaults to 5.
            worker_id_bit (int, optional): the worker_id bit of the ID. Defaults to 5.
            seq_bit (int, optional): the sequence bit of the ID. Defaults to 12.
            start_date (str): the start date of the SnowFlake object. Defaults to "2025-01-01", only used when start_epoch_ms is -1.

        Raises:
            ValueError: if machine_id or worker_id does not fit into its bit width.

        Example:
            >>> import time
            >>> snowflake = SnowFlake()
            >>> ids = [snowflake.get_id() for _ in range(10000)]
            >>> len(set(ids)) == len(ids)
            True
            >>> # test timestamp overflow
            >>> snowflake = SnowFlake(1, 1, start_date=time.strftime("%Y-%m-%d"))
            >>> timeleft = snowflake.timestamp_overflow_check() // 1000 // 60 // 60 // 24 // 365
            >>> timeleft == 69
            True
            >>> snowflake = SnowFlake(1, 1, start_date=time.strftime("%Y-%m-%d"), sign_bit=0)
            >>> timeleft = snowflake.timestamp_overflow_check() // 1000 // 60 // 60 // 24 // 365
            >>> timeleft >= 138
            True
            >>> # test machine_id and worker_id overflow
            >>> try:
            ...     snowflake = SnowFlake(32, 32)
            ... except ValueError as e:
            ...     e
            ValueError('Machine ID must be between 0 and 31')
            >>> sf = SnowFlake(32, 32, machine_id_bit=6, worker_id_bit=6)
            >>> sf.max_machine_id
            63
            >>> sf = SnowFlake(32, machine_id_bit=64)
            >>> sf.timestamp_overflow_check() < 0
            True
        """
        self.max_bits = max_bits
        self.sign_bit = sign_bit
        self.machine_id_bit = machine_id_bit
        self.worker_id_bit = worker_id_bit
        self.seq_bit = seq_bit
        # Largest value that fits in each field: 2**bits - 1
        self.max_seq = (1 << self.seq_bit) - 1
        self.max_worker_id = (1 << self.worker_id_bit) - 1
        self.max_machine_id = (1 << self.machine_id_bit) - 1
        self.max_timestamp = self.get_max_timestamp()
        # Validate inputs
        if not 0 <= machine_id <= self.max_machine_id:
            raise ValueError(f"Machine ID must be between 0 and {self.max_machine_id}")
        if not 0 <= worker_id <= self.max_worker_id:
            raise ValueError(f"Worker ID must be between 0 and {self.max_worker_id}")
        # The timestamp sits above the machine, worker and sequence fields
        self.timestamp_shift = self.seq_bit + self.worker_id_bit + self.machine_id_bit
        self.machine_id = machine_id
        self.worker_id = worker_id
        # Pre-shift the constant parts of the ID once, so get_id only ORs them in
        self.machine_id_part = machine_id << (self.seq_bit + self.worker_id_bit)
        self.worker_id_part = worker_id << self.seq_bit
        # Initialize sequence and last timestamp (-1 means "no ID generated yet")
        self.seq = 0
        self.last_timestamp = -1
        if start_epoch_ms < 0:
            self.start_epoch_ms = self.str_to_ms(start_date)
        else:
            self.start_epoch_ms = start_epoch_ms
        self.wait_for_next_ms = wait_for_next_ms

    def get_max_timestamp(self) -> int:
        """Get maximum timestamp (ms since the start epoch) that can be represented"""
        # All usable (non-sign) bits set, then drop the bits taken by the other fields
        return (1 << (self.max_bits - self.sign_bit)) - 1 >> (
            self.seq_bit + self.worker_id_bit + self.machine_id_bit
        )

    def timestamp_overflow_check(self) -> int:
        """Check how many milliseconds left until timestamp overflows (negative if already overflowed)"""
        return self.get_max_timestamp() - self._ms_passed()

    @staticmethod
    def str_to_ms(string: str) -> int:
        """Convert a "%Y-%m-%d" date string to milliseconds since the Unix epoch.

        The date is interpreted in local time, because mktime is used.
        """
        return int(mktime(strptime(string, "%Y-%m-%d")) * 1000)

    def _ms_passed(self) -> int:
        """Get current timestamp in milliseconds since start_epoch_ms"""
        return int(time() * 1000 - self.start_epoch_ms)

    def _wait_next_millis(self, last_timestamp):
        """Busy-wait until the clock moves past last_timestamp and return the new timestamp.

        Raises:
            RuntimeError: if wait_for_next_ms is False.
        """
        if not self.wait_for_next_ms:
            raise RuntimeError(
                f"Over {self.max_seq} IDs generated in 1ms, set wait_for_next_ms=True or increase seq_bit"
            )
        timestamp = self._ms_passed()
        while timestamp <= last_timestamp:
            timestamp = self._ms_passed()
        return timestamp

    def get_id(self) -> int:
        """Generate next unique ID.

        Raises:
            RuntimeError: if the clock moved backwards, or if the sequence is
                exhausted for the current millisecond and wait_for_next_ms is False.
        """
        now = self._ms_passed()

        # Clock moved backwards, reject requests
        if now < self.last_timestamp:
            raise RuntimeError(
                f"Clock moved backwards. Refusing to generate ID for {self.last_timestamp - now} milliseconds"
            )

        if now == self.last_timestamp:
            # Same timestamp, increment sequence (wraps to 0 on overflow)
            next_seq = (self.seq + 1) & self.max_seq
            if next_seq == 0:
                # Sequence overflow, wait for next millisecond. This may raise,
                # so self.seq is only updated afterwards: otherwise a failed call
                # would reset the sequence and later calls within the same
                # millisecond would hand out duplicate IDs.
                now = self._wait_next_millis(now)
            self.seq = next_seq
        else:
            # Reset sequence for different timestamp
            self.seq = 0

        self.last_timestamp = now
        # Compose ID from components
        return (
            (now << self.timestamp_shift)
            | (self.machine_id_part)
            | (self.worker_id_part)
            | self.seq
        )
| 2048 | + |
| 2049 | + |
1895 | 2050 | if __name__ == "__main__": |
1896 | 2051 | __name__ = "morebuiltins.utils" |
1897 | 2052 | import doctest |
|
0 commit comments