Skip to content

Commit a79f632

Browse files
committed
fixing var renaming error
1 parent da892d6 commit a79f632

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

src/idtr.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
408408
std::multiplies<int64_t>());
409409
assert(icSz == std::accumulate(&iDataShapePtr[1], &iDataShapePtr[iNDims], 1,
410410
std::multiplies<int64_t>()));
411-
int64_t mySz = icSz * iGShapePtr[0];
412-
if (mySz / icSz != iGShapePtr[0]) {
411+
int64_t mySz = icSz * iDataShapePtr[0];
412+
if (mySz / icSz != iDataShapePtr[0]) {
413413
throw std::overflow_error("Fatal: Integer overflow in reshape");
414414
}
415415
int64_t myOff = iOffsPtr[0] * icSz;
@@ -425,8 +425,8 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
425425
std::multiplies<int64_t>());
426426
assert(oCSz == std::accumulate(&oDataShapePtr[1], &oDataShapePtr[oNDims], 1,
427427
std::multiplies<int64_t>()));
428-
int64_t myOSz = oCSz * oGShapePtr[0];
429-
if (myOSz / oCSz != oGShapePtr[0]) {
428+
int64_t myOSz = oCSz * oDataShapePtr[0];
429+
if (myOSz / oCSz != oDataShapePtr[0]) {
430430
throw std::overflow_error("Fatal: Integer overflow in reshape");
431431
}
432432
int64_t myOOff = oOffsPtr[0] * oCSz;
@@ -495,6 +495,11 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
495495
auto hdl = tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), sharpytype,
496496
oDataPtr, rszs.data(), roffs.data());
497497

498+
if (no_async) {
499+
tc->wait(hdl);
500+
return nullptr;
501+
}
502+
498503
auto wait = [tc = tc, hdl = hdl, outbuff = std::move(outbuff),
499504
sszs = std::move(sszs), soffs = std::move(soffs),
500505
rszs = std::move(rszs),

test/test_ewb.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def doit(aapi, **kwargs):
8787
a = aapi.reshape(
8888
aapi.arange(0, 64, 1, dtype=aapi.float32, **kwargs), [8, 8]
8989
)
90-
# b = aapi.reshape(
91-
# aapi.arange(0, 64, 1, dtype=aapi.float32, **kwargs), [8, 8]
92-
# )
93-
# c = a[2:6, 0:8]
94-
# d = b[0:8:2, 0:8]
95-
return a
90+
b = aapi.reshape(
91+
aapi.arange(0, 64, 1, dtype=aapi.float32, **kwargs), [8, 8]
92+
)
93+
c = a[2:6, 0:8]
94+
d = b[0:8:2, 0:8]
95+
return c + d
9696

9797
assert runAndCompare(doit)
9898

0 commit comments

Comments
 (0)