@@ -946,8 +946,9 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx:
946
946
register_executor (ex )
947
947
948
948
949
- def register_supported (id : Hashable , translator : Callable , checker : Callable ):
950
- ex .register_supported (id , checker )
949
+ def register_supported (sym_or_id : Hashable , translator : Callable , checker : Callable ):
950
+ ex .register_supported (sym_or_id , checker )
951
+ id = sym_or_id .id if isinstance (sym_or_id , Symbol ) else sym_or_id
951
952
_translation_map [id ] = translator
952
953
953
954
@@ -2582,3 +2583,47 @@ def scaled_dot_product_flash_attention_grad(
2582
2583
execution_transform = scaled_dot_product_flash_attention ,
2583
2584
grad_transform = scaled_dot_product_flash_attention_grad ,
2584
2585
)
2586
+
2587
+
2588
+ def _embedding_check (
2589
+ input : TensorProxy ,
2590
+ weight : TensorProxy ,
2591
+ padding_idx : None | int ,
2592
+ max_norm : None | float ,
2593
+ norm_type : None | float ,
2594
+ scale_grad_by_freq : None | bool ,
2595
+ sparse : None | bool ,
2596
+ ) -> bool :
2597
+ if nvfuser_version () < LooseVersion ("0.2.25" ):
2598
+ return False
2599
+ enable_embedding : None | bool = get_compile_option ("nv_enable_embedding" , "Enable nvFuser embedding." )
2600
+ if not enable_embedding :
2601
+ return False
2602
+ # Verify input and weight are supported tensors.
2603
+ if not are_supported_tensors (input , weight ) or (weight .ndim != 2 ):
2604
+ return False
2605
+ return True
2606
+
2607
+
2608
+ def embedding (
2609
+ input : TensorProxy ,
2610
+ weight : TensorProxy ,
2611
+ padding_idx : None | int = None ,
2612
+ max_norm : None | float = None ,
2613
+ norm_type : None | float = 2.0 ,
2614
+ scale_grad_by_freq : None | bool = False ,
2615
+ sparse : None | bool = False ,
2616
+ * ,
2617
+ fd : FusionDefinition ,
2618
+ lc_to_nv_map : dict ,
2619
+ ) -> Any :
2620
+ inputs = [input , weight , padding_idx , max_norm , norm_type , scale_grad_by_freq , sparse ]
2621
+ nv_inputs = []
2622
+ for inp in inputs :
2623
+ nv_inp = getnv (inp , fd , lc_to_nv_map ) if inp is not None else None
2624
+ nv_inputs .append (nv_inp )
2625
+ return fd .ops .embedding_fwd (* nv_inputs )
2626
+
2627
+
2628
+ register_supported (PrimIDs .EMBEDDING , embedding , _embedding_check )
2629
+ register_supported (ltorch .embedding , embedding , _embedding_check )
0 commit comments