diff --git a/megatron/core/transformer/moe/grouped_gemm_util.py b/megatron/core/transformer/moe/grouped_gemm_util.py index e7ef79d795..409244de7c 100644 --- a/megatron/core/transformer/moe/grouped_gemm_util.py +++ b/megatron/core/transformer/moe/grouped_gemm_util.py @@ -1,5 +1,9 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +from importlib.metadata import version + +from pkg_resources import packaging + try: import grouped_gemm except ImportError: @@ -13,7 +17,13 @@ def grouped_gemm_is_available(): def assert_grouped_gemm_is_available(): assert grouped_gemm_is_available(), ( "Grouped GEMM is not available. Please run " - "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`." + "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.2`." + ) + + _gg_version = packaging.version.Version(version("grouped_gemm")) + assert _gg_version >= packaging.version.Version("1.1.2"), ( + "Grouped GEMM should be v1.1.2 or newer. Please run " + "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.2`." )