Skip to content

Commit 9f2cdfe

Browse files
cspotcodeZuzu-Typ
andauthored
Fix #210: matrix multiplication of array of vectors (#212)
* fix * Fix matrix multiplication in arrays + Should fix #210 + Fixed matrix multiplication for `arr * arr`, `mat * arr` and `arr * mat` --------- Co-authored-by: Zuzu-Typ <[email protected]>
1 parent 97841d8 commit 9f2cdfe

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

PyGLM/type_methods/glmArray.h

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5134,15 +5134,13 @@ static PyObject* glmArray_mul_T_MMUL(glmArray* arr1, glmArray* arr2) {
51345134
outArray->readonly = false;
51355135
outArray->reference = NULL;
51365136

5137-
Py_ssize_t n, arr1Stride, arr2Stride, trueShape1;
5137+
Py_ssize_t n, arr1Stride, arr2Stride;
51385138

51395139
if (arr1->glmType == PyGLM_TYPE_VEC) {
51405140
n = arr1->shape[0];
51415141
arr1Stride = 1;
51425142
arr2Stride = arr2->shape[1];
51435143

5144-
trueShape1 = 1;
5145-
51465144
outArray->glmType = PyGLM_TYPE_VEC;
51475145
outArray->shape[0] = arr2->shape[0];
51485146
outArray->shape[1] = 0;
@@ -5155,8 +5153,6 @@ static PyObject* glmArray_mul_T_MMUL(glmArray* arr1, glmArray* arr2) {
51555153
arr1Stride = arr1->shape[1];
51565154
arr2Stride = 0;
51575155

5158-
trueShape1 = arr1->shape[1];
5159-
51605156
outArray->glmType = PyGLM_TYPE_VEC;
51615157
outArray->shape[0] = arr1->shape[1];
51625158
outArray->shape[1] = 0;
@@ -5169,8 +5165,6 @@ static PyObject* glmArray_mul_T_MMUL(glmArray* arr1, glmArray* arr2) {
51695165
arr1Stride = arr1->shape[1];
51705166
arr2Stride = arr2->shape[1];
51715167

5172-
trueShape1 = arr1->shape[1];
5173-
51745168
outArray->glmType = PyGLM_TYPE_MAT;
51755169
outArray->shape[0] = arr2->shape[0];
51765170
outArray->shape[1] = arr1->shape[1];
@@ -5198,13 +5192,16 @@ static PyObject* glmArray_mul_T_MMUL(glmArray* arr1, glmArray* arr2) {
51985192
for (Py_ssize_t j = 0; j < outArrayRatio; j++) {
51995193
T result = (T)0;
52005194
for (Py_ssize_t k = 0; k < n; k++) {
5201-
T a = arr1DataPtr[k * arr1Stride + j % trueShape1];
5202-
T b = arr2DataPtr[k + (j / trueShape1) * arr2Stride];
5195+
T a = arr1DataPtr[k * arr1Stride + j % arr1Stride];
5196+
T b = arr2DataPtr[k + (j / arr1Stride) * arr2Stride];
52035197

52045198
result += a * b;
52055199
}
52065200
outArrayDataPtr[outArrayIndex++] = result;
52075201
}
5202+
// move pointers by one item
5203+
arr1DataPtr = reinterpret_cast<T*>(reinterpret_cast<char*>(arr1DataPtr) + arr1->itemSize);
5204+
arr2DataPtr = reinterpret_cast<T*>(reinterpret_cast<char*>(arr2DataPtr) + arr2->itemSize);
52085205
}
52095206

52105207
return (PyObject*)outArray;
@@ -5283,15 +5280,13 @@ static PyObject* glmArray_mulO_T(glmArray* arr, T* o, Py_ssize_t o_size, PyGLMTy
52835280
return (PyObject*)outArray;
52845281
}
52855282

5286-
Py_ssize_t n, arrStride, oStride, trueShape1;
5283+
Py_ssize_t n, arrStride, oStride;
52875284

52885285
if (arr->glmType == PyGLM_TYPE_VEC) {
52895286
n = arr->shape[0];
52905287
arrStride = 1;
52915288
oStride = pto->R;
52925289

5293-
trueShape1 = 1;
5294-
52955290
outArray->glmType = PyGLM_TYPE_VEC;
52965291
outArray->shape[0] = pto->C;
52975292
outArray->shape[1] = 0;
@@ -5304,8 +5299,6 @@ static PyObject* glmArray_mulO_T(glmArray* arr, T* o, Py_ssize_t o_size, PyGLMTy
53045299
arrStride = arr->shape[1];
53055300
oStride = 0;
53065301

5307-
trueShape1 = arr->shape[1];
5308-
53095302
outArray->glmType = PyGLM_TYPE_VEC;
53105303
outArray->shape[0] = arr->shape[1];
53115304
outArray->shape[1] = 0;
@@ -5318,8 +5311,6 @@ static PyObject* glmArray_mulO_T(glmArray* arr, T* o, Py_ssize_t o_size, PyGLMTy
53185311
arrStride = arr->shape[1];
53195312
oStride = pto->R;
53205313

5321-
trueShape1 = arr->shape[1];
5322-
53235314
outArray->glmType = PyGLM_TYPE_MAT;
53245315
outArray->shape[0] = pto->C;
53255316
outArray->shape[1] = arr->shape[1];
@@ -5346,13 +5337,14 @@ static PyObject* glmArray_mulO_T(glmArray* arr, T* o, Py_ssize_t o_size, PyGLMTy
53465337
for (Py_ssize_t j = 0; j < outArrayRatio; j++) {
53475338
T result = (T)0;
53485339
for (Py_ssize_t k = 0; k < n; k++) {
5349-
T a = arrDataPtr[k * arrStride + j % trueShape1];
5350-
T b = o[k + (j / trueShape1) * oStride];
5340+
T a = arrDataPtr[k * arrStride + j % arrStride];
5341+
T b = o[k + (j / arrStride) * oStride];
53515342

53525343
result = result + a * b;
53535344
}
53545345
outArrayDataPtr[outArrayIndex++] = result;
53555346
}
5347+
arrDataPtr = reinterpret_cast<T*>(reinterpret_cast<char*>(arrDataPtr) + arr->itemSize);
53565348
}
53575349

53585350
return (PyObject*)outArray;
@@ -5441,6 +5433,7 @@ static PyObject* glmArray_rmulO_T(glmArray* arr, T* o, Py_ssize_t o_size, PyGLMT
54415433
}
54425434
outArrayDataPtr[outArrayIndex++] = result;
54435435
}
5436+
arrDataPtr = reinterpret_cast<T*>(reinterpret_cast<char*>(arrDataPtr) + arr->itemSize);
54445437
}
54455438

54465439
return (PyObject*)outArray;

test/PyGLM_test.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3643,7 +3643,36 @@ def test_mat2x2():
36433643
##/core_type_mat2x2 ##
36443644
###/GLM TESTS ###
36453645

3646-
3646+
def test_array_matmul():
3647+
hitbox = glm.array(
3648+
glm.vec2(0., 1.),
3649+
glm.vec2(1., 1.),
3650+
glm.vec2(1., 0.),
3651+
glm.vec2(0., 0.),
3652+
)
3653+
rotation = glm.radians(90)
3654+
scale_x = 1.
3655+
scale_y = 2.
3656+
3657+
cos_rotation = glm.cos(rotation)
3658+
sin_rotation = glm.sin(rotation)
3659+
3660+
rotation_scale_matrix = glm.mat2x2(
3661+
scale_x * cos_rotation, -scale_y * sin_rotation,
3662+
scale_x * sin_rotation, scale_y * cos_rotation
3663+
)
3664+
hitbox_rotated = hitbox * rotation_scale_matrix
3665+
3666+
# I expect to see the points rotated 90 degrees, approximately:
3667+
# array(vec2(-2, 0), vec2(-2, 1), vec2(0, 1), vec2(0, 0))
3668+
assert glm.equal(hitbox_rotated[0].x, -2, 0.00001)
3669+
assert glm.equal(hitbox_rotated[0].y, 0, 0.00001)
3670+
assert glm.equal(hitbox_rotated[1].x, -2, 0.00001)
3671+
assert glm.equal(hitbox_rotated[1].y, 1, 0.00001)
3672+
assert glm.equal(hitbox_rotated[2].x, 0, 0.00001)
3673+
assert glm.equal(hitbox_rotated[2].y, 1, 0.00001)
3674+
assert glm.equal(hitbox_rotated[3].x, 0, 0.00001)
3675+
assert glm.equal(hitbox_rotated[3].y, 0, 0.00001)
36473676

36483677
### TEST TEST ###
36493678

0 commit comments

Comments
 (0)