@@ -42,53 +42,108 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
4242 #if ULAB_MAX_DIMS > 1
4343 // no need to check anything, if the maximum number of dimensions is 1
4444 if (input -> ndim != 1 ) {
45- mp_raise_ValueError (MP_ERROR_TEXT ("object too deep for desired arrayy " ));
45+ mp_raise_ValueError (MP_ERROR_TEXT ("object too deep for desired array " ));
4646 }
4747 #endif
4848 if ((input -> dtype != NDARRAY_UINT8 ) && (input -> dtype != NDARRAY_UINT16 )) {
4949 mp_raise_TypeError (MP_ERROR_TEXT ("cannot cast array data from dtype" ));
5050 }
5151
5252 // first find the maximum of the array, and figure out how long the result should be
53- uint16_t max = 0 ;
53+ size_t length = 0 ;
5454 int32_t stride = input -> strides [ULAB_MAX_DIMS - 1 ];
5555 if (input -> dtype == NDARRAY_UINT8 ) {
5656 uint8_t * iarray = (uint8_t * )input -> array ;
5757 for (size_t i = 0 ; i < input -> len ; i ++ ) {
58- if (* iarray > max ) {
59- max = * iarray ;
58+ if (* iarray > length ) {
59+ length = * iarray ;
6060 }
6161 iarray += stride ;
6262 }
6363 } else if (input -> dtype == NDARRAY_UINT16 ) {
6464 stride /= 2 ;
6565 uint16_t * iarray = (uint16_t * )input -> array ;
6666 for (size_t i = 0 ; i < input -> len ; i ++ ) {
67- if (* iarray > max ) {
68- max = * iarray ;
67+ if (* iarray > length ) {
68+ length = * iarray ;
6969 }
7070 iarray += stride ;
7171 }
7272 }
73- ndarray_obj_t * result = ndarray_new_linear_array ( max + 1 , NDARRAY_UINT16 ) ;
73+ length += 1 ;
7474
75- // now we can do the binning
76- uint16_t * rarray = (uint16_t * )result -> array ;
75+ if (args [2 ].u_obj != mp_const_none ) {
76+ int32_t minlength = mp_obj_get_int (args [2 ].u_obj );
77+ if (minlength < 0 ) {
78+ mp_raise_ValueError (MP_ERROR_TEXT ("minlength must not be negative" ));
79+ }
80+ if ((size_t )minlength > length ) {
81+ length = minlength ;
82+ }
83+ }
7784
78- if (input -> dtype == NDARRAY_UINT8 ) {
79- uint8_t * iarray = (uint8_t * )input -> array ;
80- for (size_t i = 0 ; i < input -> len ; i ++ ) {
81- rarray [* iarray ] += 1 ;
82- iarray += stride ;
85+ ndarray_obj_t * result = NULL ;
86+ ndarray_obj_t * weights = NULL ;
87+
88+ if (args [1 ].u_obj == mp_const_none ) {
89+ result = ndarray_new_linear_array (length , NDARRAY_UINT16 );
90+ } else {
91+ if (!mp_obj_is_type (args [1 ].u_obj , & ulab_ndarray_type )) {
92+ mp_raise_TypeError (MP_ERROR_TEXT ("input must be an ndarray" ));
8393 }
84- } else if (input -> dtype == NDARRAY_UINT16 ) {
85- uint16_t * iarray = (uint16_t * )input -> array ;
86- for (size_t i = 0 ; i < input -> len ; i ++ ) {
87- rarray [* iarray ] += 1 ;
88- iarray += stride ;
94+ weights = MP_OBJ_TO_PTR (args [1 ].u_obj );
95+ result = ndarray_new_linear_array (length , NDARRAY_FLOAT );
96+ }
97+
98+ // now we can do the binning
99+ if (result -> dtype == NDARRAY_UINT16 ) {
100+ uint16_t * rarray = (uint16_t * )result -> array ;
101+ if (input -> dtype == NDARRAY_UINT8 ) {
102+ uint8_t * iarray = (uint8_t * )input -> array ;
103+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
104+ rarray [* iarray ] += 1 ;
105+ iarray += stride ;
106+ }
107+ } else if (input -> dtype == NDARRAY_UINT16 ) {
108+ uint16_t * iarray = (uint16_t * )input -> array ;
109+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
110+ rarray [* iarray ] += 1 ;
111+ iarray += stride ;
112+ }
113+ }
114+ } else {
115+ mp_float_t * rarray = (mp_float_t * )result -> array ;
116+ if (input -> dtype == NDARRAY_UINT8 ) {
117+ uint8_t * iarray = (uint8_t * )input -> array ;
118+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
119+ rarray [* iarray ] += MICROPY_FLOAT_CONST (1.0 );
120+ iarray += stride ;
121+ }
122+ } else if (input -> dtype == NDARRAY_UINT16 ) {
123+ uint16_t * iarray = (uint16_t * )input -> array ;
124+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
125+ rarray [* iarray ] += MICROPY_FLOAT_CONST (1.0 );
126+ iarray += stride ;
127+ }
89128 }
90129 }
91130
131+ if (weights != NULL ) {
132+ mp_float_t (* get_weights )(void * ) = ndarray_get_float_function (weights -> dtype );
133+ mp_float_t * rarray = (mp_float_t * )result -> array ;
134+ uint8_t * warray = (uint8_t * )weights -> array ;
135+
136+ size_t fill_length = result -> len ;
137+ if (weights -> len < result -> len ) {
138+ fill_length = weights -> len ;
139+ }
140+
141+ for (size_t i = 0 ; i < fill_length ; i ++ ) {
142+ * rarray = * rarray * get_weights (warray );
143+ rarray ++ ;
144+ warray += weights -> strides [ULAB_MAX_DIMS - 1 ];
145+ }
146+ }
92147 return MP_OBJ_FROM_PTR (result );
93148}
94149
0 commit comments