using nccl ops from TRT-LLM namespace #3250
base: main
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-19 00:55:11.232553+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-19 00:55:32.513756+00:00
@@ -84,11 +84,11 @@
ctypes.CDLL(plugin_lib_path)
logger.info(f"plugin loaded successfully")
except OSError as e:
logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
-#Iterate over all registered plugin creators
+# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
logger.info(
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
)
Force-pushed from c916bf6 to 195b1c4
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-21 20:25:45.697459+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-21 20:26:10.941910+00:00
@@ -26,44 +26,51 @@
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging
+
"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
- ctypes.CDLL("/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so")
+ ctypes.CDLL(
+ "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+ )
print("plugin loaded sucessfully")
except OSError as e:
print(f"unsuccessful load : {e}")
logger = trt.Logger(trt.Logger.VERBOSE)
-trt.init_libnvinfer_plugins(None, '')
-#-[p;Iterate over all registered plugin creators
+trt.init_libnvinfer_plugins(None, "")
+# -[p;Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
- print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
+ print(
+ f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+ )
@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
- name: str,
+ name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
world_size = dist.get_world_size()
group = list(range(world_size))
- group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
+ group = trt.PluginField(
+ "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+ )
p_dtype = trt.float16
pf_type = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = trt.PluginFieldCollection([group, pf_type])
Force-pushed from 8015490 to a27b719
logger.info(f"plugin loaded successfully") | ||
except OSError as e: | ||
logger.info(f"unsuccessful load : {e}") | ||
trt.init_libnvinfer_plugins(None, "") |
Do you need these lines as well?
I don't think these lines are actually required. I just tested the code without them, and having "import tensorrt_llm" is enough for the plugins with the tensorrt_llm namespace to be loaded.
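A quick way to confirm that (a minimal sketch; the creator lookup mirrors the converter code in this PR):

# Sketch: importing tensorrt_llm registers its plugins as a side effect,
# so the creator should be discoverable without an explicit ctypes.CDLL load.
import tensorrt as trt
import tensorrt_llm  # noqa: F401  (import triggers plugin registration)

registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator("AllGather", "1", "tensorrt_llm")
assert creator is not None, "tensorrt_llm plugins were not registered"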
logger.info(f"unsuccessful load : {e}") | ||
trt.init_libnvinfer_plugins(None, "") | ||
# Iterate over all registered plugin creators | ||
plugin_registry = trt.get_plugin_registry() |
Is this just for debugging purposes?
Yes, it is to check whether the plugins with the "tensorrt_llm" namespace have been loaded properly.
"AllGather", "1", "tensorrt_llm" | ||
) | ||
assert allgather_plg_creator is not None | ||
world_size = dist.get_world_size() |
How might the converter get this info if it were in the library?
I am not clear what is meant by library here. Do you mean aten_ops_converters.py? Generally the converter should get this info when the distributed environment is initialized. That happens implicitly when using torchrun, but we initialize it explicitly in initialize_distributed_env().
OK, so we don't need a dist object? Can we use that version here?
Yes, we could if we maintain a global variable for it and use that in the file. But the dist object would still be required for initialization.
Did you verify that numerical results are correct here?
Yes @narendasan, the numerical results come out correct for this example and for llama3, within a 0.01 error threshold.
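For reference, a minimal sketch of such a check (module and input names are hypothetical; the threshold is the one quoted above):

# Hypothetical check: compare the Torch-TensorRT output against the eager
# PyTorch reference within the 0.01 error threshold.
import torch

ref_output = eager_model(example_input)      # plain PyTorch reference
trt_output = compiled_model(example_input)   # Torch-TensorRT compiled module
assert torch.allclose(trt_output, ref_output, atol=1e-2, rtol=1e-2)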
Force-pushed from a27b719 to b6f5980
group = trt.PluginField(
    "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
p_dtype = trt.float16
Do these kernels only support FP16?
No, they can support FP32 too.
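In that case the type_id field could be derived from the input instead of being hardcoded; a sketch under that assumption:

# Sketch: pick the plugin dtype from the incoming TRT tensor, falling back to
# FP16 (assumes the plugin accepts both type ids).
import numpy as np
import tensorrt as trt

p_dtype = args[0].dtype if args[0].dtype in (trt.float16, trt.float32) else trt.float16
pf_type = trt.PluginField(
    "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)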
logger = logging.getLogger(__name__)

def custom_fused_all_gather_op(args0, args1, args2):
Let's call this something like tensorrt_fused_nccl_all_gather.
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}" | ||
) | ||
|
||
@dynamo_tensorrt_converter(custom_fused_all_gather_op) |
We want to start thinking about how these might get added as actual converters, like how we support quantization. I think the global variable dependency is an issue. How might we work around that?
Yes, instead of using the dist package we could pull in the global variable, assuming the environment variable is set and initialized during the initialization step.
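A sketch of that approach, assuming the launcher has exported WORLD_SIZE before compilation:

# Sketch: read the group size from the environment instead of torch.distributed,
# so the converter itself has no dependency on an initialized process group.
import os

world_size = int(os.environ["WORLD_SIZE"])  # assumed set by the torchrun/mpirun wrapper
group = list(range(world_size))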
# Initialization
initialize_distributed_env()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])
Things like this I am OK pulling in "globally", since we can assume the env variable is set and presumably this is what people are doing already.
There are some changes that do not conform to Python style guidelines:
Force-pushed from be6241c to 0837fc2
There are some changes that do not conform to Python style guidelines:
Force-pushed from 0837fc2 to 6707c6f
Force-pushed from 38335b9 to 06fb7a8
Force-pushed from 06fb7a8 to 509d917
Force-pushed from 509d917 to 6ffc284
Force-pushed from 6ffc284 to e96ce78
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, settings)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

complex_nodes = find_complex_nodes(gm)
Why isn't this part of lowering?
I need to get the complex_nodes before running the replace_complex_placeholder_to_tuple() lowering pass. I could put this inside the replace_complex_placeholder_to_tuple lowering pass, but since it is a util function pertaining more to modify_complex_nodes, I put it in utils.
I mean, it seems to me these 3-4 lines are a pass that can be added to lowering.
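Something like the following sketch of that suggestion (function names are taken from this PR; the pass signature is assumed):

# Hypothetical lowering pass wrapping the complex-node handling from backends.py.
def complex_graph_detection_pass(gm, complex_input_indices):
    complex_nodes = find_complex_nodes(gm)
    if complex_nodes:
        replace_complex_placeholder_to_tuple(gm, complex_input_indices)
        modify_complex_nodes(gm, complex_nodes)
    return gm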
@@ -3590,3 +3592,76 @@ def aten_ops_full(
    fill_value=args[1],
    dtype=kwargs.get("dtype", None),
)

try:
Turn this into a utility function that loads the TRT-LLM plugin lib and returns a bool on success, and use that to condition the converter.
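A sketch of that utility (the TRTLLM_PLUGINS_PATH name follows the naming suggestion later in this thread):

# Hypothetical utility: load the TRT-LLM plugin library once and report success,
# so converter registration can be gated on the returned bool.
import ctypes
import logging
import os

logger = logging.getLogger(__name__)

def load_tensorrt_llm() -> bool:
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path is None:
        logger.warning("TRTLLM_PLUGINS_PATH not set; NCCL converters disabled")
        return False
    try:
        ctypes.CDLL(plugin_lib_path)
        logger.info("TRT-LLM plugins loaded successfully")
        return True
    except OSError as e:
        logger.warning(f"Could not load TRT-LLM plugins: {e}")
        return False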
counter = 0
strategy = AllReduceStrategy.NCCL
config = AllReduceConfig(0)
_world_size = os.environ.get("WORLD_SIZE")
Does WORLD_SIZE get baked into the engine? If I load from a serialized engine, do I need the env variable?
No, I don't think so. This relies on the initialize_distributed_env() function in tensor_parallel_dist_env.py to do that. A torchrun command would do it implicitly, but since we use mpirun for the NCCL ops that TRT-LLM supports, we need to initialize the variables explicitly.
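Under mpirun the torchrun-style variables are not set automatically, so initialize_distributed_env() has to derive them itself; a sketch of what that might look like (OpenMPI variable names assumed):

# Sketch: populate torchrun-style env variables from OpenMPI's, then initialize
# the process group so dist calls work under mpirun as well.
import os
import torch.distributed as dist

def initialize_distributed_env(port: int = 29500) -> None:
    os.environ.setdefault("RANK", os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
    os.environ.setdefault("WORLD_SIZE", os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", str(port))
    dist.init_process_group(backend="nccl")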
Force-pushed from e96ce78 to d161946
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py 2024-12-23 18:40:27.812736+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py 2024-12-23 18:40:48.084051+00:00
@@ -133,11 +133,11 @@
gm = post_lowering(gm, settings)
logger.debug("Lowered Input graph:\n " + str(gm.graph))
complex_nodes = find_complex_nodes(gm)
- if (complex_nodes):
+ if complex_nodes:
replace_complex_placeholder_to_tuple(gm, complexInputIndices)
modify_complex_nodes(gm, complex_nodes)
torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
@apbose run the linter
os.environ["MASTER_PORT"] = str(port) | ||
# Note this will not work in the initialization here | ||
# You would need to set it externally as a user | ||
os.environ["trtllm_env"] = ( |
This should be all caps and something like TRTLLM_PLUGINS_PATH.
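i.e., something along these lines (the path value itself is elided in the diff above):

# Suggested rename applied: all-caps, descriptive name for the plugin path.
os.environ["TRTLLM_PLUGINS_PATH"] = (
    plugin_lib_path  # hypothetical variable holding the .so path
)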
Force-pushed from a18ba8b to 3148697
Force-pushed from 3148697 to b77a971
visited_nodes.add(node)
update_node_meta(node, fake_mode)
for user in node.users:
    if (
Why do we have this special case?
It is a terminating case for the Llama model's complex placeholder nodes. The graph looks like complex placeholder -> reshape -> slice -> complex mul; we need the metadata for reshape and slice to be amended, stopping at the mul node (we replace the complex mul with a custom torchTRT mul later in modify_reshape_complex_nodes).
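A sketch of the traversal being described (helper names are taken from the quoted code; the op check is assumed):

# Sketch: amend metadata through reshape/slice users, but stop at the complex
# mul, which modify_reshape_complex_nodes later replaces with a torchTRT mul.
import torch

def propagate_complex_meta(node, fake_mode, visited_nodes):
    visited_nodes.add(node)
    update_node_meta(node, fake_mode)
    for user in node.users:
        if user.target == torch.ops.aten.mul.Tensor:  # terminating case
            continue
        if user not in visited_nodes:
            propagate_complex_meta(user, fake_mode, visited_nodes)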
This PR illustrates the use of NCCL ops from TRT-LLM in the example examples/distributed_inference/tensor_parallel_simple_example.py.