@@ -408,8 +408,8 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
408
408
std::multiplies<int64_t >());
409
409
assert (icSz == std::accumulate (&iDataShapePtr[1 ], &iDataShapePtr[iNDims], 1 ,
410
410
std::multiplies<int64_t >()));
411
- int64_t mySz = icSz * iGShapePtr [0 ];
412
- if (mySz / icSz != iGShapePtr [0 ]) {
411
+ int64_t mySz = icSz * iDataShapePtr [0 ];
412
+ if (mySz / icSz != iDataShapePtr [0 ]) {
413
413
throw std::overflow_error (" Fatal: Integer overflow in reshape" );
414
414
}
415
415
int64_t myOff = iOffsPtr[0 ] * icSz;
@@ -425,8 +425,8 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
425
425
std::multiplies<int64_t >());
426
426
assert (oCSz == std::accumulate (&oDataShapePtr[1 ], &oDataShapePtr[oNDims], 1 ,
427
427
std::multiplies<int64_t >()));
428
- int64_t myOSz = oCSz * oGShapePtr [0 ];
429
- if (myOSz / oCSz != oGShapePtr [0 ]) {
428
+ int64_t myOSz = oCSz * oDataShapePtr [0 ];
429
+ if (myOSz / oCSz != oDataShapePtr [0 ]) {
430
430
throw std::overflow_error (" Fatal: Integer overflow in reshape" );
431
431
}
432
432
int64_t myOOff = oOffsPtr[0 ] * oCSz;
@@ -495,6 +495,11 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
495
495
auto hdl = tc->alltoall (outbuff.data (), sszs.data (), soffs.data (), sharpytype,
496
496
oDataPtr, rszs.data (), roffs.data ());
497
497
498
+ if (no_async) {
499
+ tc->wait (hdl);
500
+ return nullptr ;
501
+ }
502
+
498
503
auto wait = [tc = tc, hdl = hdl, outbuff = std::move (outbuff),
499
504
sszs = std::move (sszs), soffs = std::move (soffs),
500
505
rszs = std::move (rszs),
0 commit comments