@@ -5356,6 +5356,7 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs):
53565356
53575357unsqueeze_opinfo = OpInfo (
53585358 clang .unsqueeze ,
5359+ supports_grad = True ,
53595360 sample_input_generator = unsqueeze_sample_generator ,
53605361 jax_reference = jax .lax .expand_dims if JAX_AVAILABLE else None ,
53615362 test_directives = (
@@ -6018,6 +6019,53 @@ def topk_error_generator(op, device, **kwargs):
60186019reduction_ops .append (topk_opinfo )
60196020
60206021
6022+ def atleast_1d2d3d_sample_generator (op , device , dtype , requires_grad , ** kwargs ):
6023+ make = partial (make_tensor , dtype = dtype , device = device , requires_grad = requires_grad )
6024+
6025+ cases = (
6026+ (),
6027+ (4 ,),
6028+ (5 , 5 ),
6029+ (6 , 7 , 8 ),
6030+ (3 , 3 , 3 , 3 ),
6031+ )
6032+
6033+ for c in cases :
6034+ yield SampleInput (make (c ))
6035+
6036+ yield SampleInput (make (()), make ((2 ,)))
6037+ yield SampleInput (make ((2 ,)), make ((5 , 5 )))
6038+ yield SampleInput (make (()), make ((2 ,)), make ((4 , 4 )))
6039+ yield SampleInput (make (2 , 3 ), make (4 , 5 ), make (6 , 6 , 6 ), make (5 , 5 , 5 , 5 ))
6040+
6041+
6042+ atleast_1d_opinfo = OpInfo (
6043+ ltorch .atleast_1d ,
6044+ supports_grad = True ,
6045+ sample_input_generator = atleast_1d2d3d_sample_generator ,
6046+ torch_reference = torch .atleast_1d ,
6047+ )
6048+ reduction_ops .append (atleast_1d_opinfo )
6049+
6050+
6051+ atleast_2d_opinfo = OpInfo (
6052+ ltorch .atleast_2d ,
6053+ supports_grad = True ,
6054+ sample_input_generator = atleast_1d2d3d_sample_generator ,
6055+ torch_reference = torch .atleast_2d ,
6056+ )
6057+ reduction_ops .append (atleast_2d_opinfo )
6058+
6059+
6060+ atleast_3d_opinfo = OpInfo (
6061+ ltorch .atleast_3d ,
6062+ supports_grad = True ,
6063+ sample_input_generator = atleast_1d2d3d_sample_generator ,
6064+ torch_reference = torch .atleast_3d ,
6065+ )
6066+ reduction_ops .append (atleast_3d_opinfo )
6067+
6068+
60216069opinfos .extend (reduction_ops )
60226070
60236071
0 commit comments