Open
Description
Hi, there seems to be an inconsistency in how rows and columns are interpreted in cute::m3.
Specifically, the following consistency property of outer_product fails to hold:
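Given three vectors $u$, $v$, $x$, the outer product $M = u v^T$ should satisfy $M x = u\,(v \cdot x)$.
However, in cute_math, the outer product of $u$ and $v$ behaves like $v u^T$, so cute::mul(M, x) returns $v\,(u \cdot x)$ instead.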
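For reference, associativity of matrix multiplication shows how the intended product and its transpose act on $x$:

$$
\begin{aligned}
(u v^T)\,x &= u\,(v^T x) = u\,(v \cdot x) \\
(v u^T)\,x &= v\,(u^T x) = v\,(u \cdot x)
\end{aligned}
$$

With the values from the test case, $u = (1, 2, 3)$, $v = (4, 5, 6)$, $x = (7, 8, 9)$, we get $v \cdot x = 28 + 40 + 54 = 122$ and $u \cdot x = 7 + 16 + 27 = 50$. The first line therefore gives $(122, 244, 366)$, the expected rhs, and the second gives $(200, 250, 300)$, which is exactly the lhs printed by the test case.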
It's not entirely clear which way this should be fixed: outer_product constructs the m3 with rows(...), which is correct if those vectors are matrix rows, but cute::mul interprets them as columns, hence the transposition.
(found via automated fuzzing)
The issue is demonstrated in the following testcase:
testcase.cpp
#include <cmath>
#include <cassert>
#include <cstdio>
#include "cute_math.h"

int main()
{
    // Small, deterministic values to show the error clearly.
    cute::v3 u(1.0f, 2.0f, 3.0f);
    cute::v3 v(4.0f, 5.0f, 6.0f);
    cute::v3 x(7.0f, 8.0f, 9.0f);

    cute::m3 M = cute::outer_product(u, v); // Expected: M = u v^T
    cute::v3 lhs = cute::mul(M, x);
    printf("lhs: %f, %f, %f\n", cute::getx(lhs), cute::gety(lhs), cute::getz(lhs));

    // Property for true outer product: M*x == u * dot(v, x)
    float s = cute::dot(v, x);
    cute::v3 rhs = u * s;
    printf("rhs: %f, %f, %f\n", cute::getx(rhs), cute::gety(rhs), cute::getz(rhs));

    float dx = cute::getx(lhs) - cute::getx(rhs);
    float dy = cute::gety(lhs) - cute::gety(rhs);
    float dz = cute::getz(lhs) - cute::getz(rhs);

    // This assertion fails with the current implementation because
    // outer_product appears to return v u^T instead of u v^T.
    float tol = 1e-5f;
    assert(std::fabs(dx) <= tol && std::fabs(dy) <= tol && std::fabs(dz) <= tol);
    return 0;
}
output
lhs: 200.000000, 250.000000, 300.000000
rhs: 122.000000, 244.000000, 366.000000