@@ -569,6 +569,9 @@ typedef struct state {
569569
570570/* Forward declarations */
571571static me_expr * new_expr (const int type , const me_expr * parameters []);
572+ static me_dtype infer_output_type (const me_expr * n );
573+ static void private_eval (const me_expr * n );
574+ static void eval_reduction (const me_expr * n , int output_nitems );
572575static double conj_wrapper (double x );
573576static double imag_wrapper (double x );
574577static double real_wrapper (double x );
@@ -631,15 +634,49 @@ static bool contains_reduction(const me_expr* n) {
631634}
632635
633636static bool reduction_usage_is_valid (const me_expr * n ) {
634- if (!is_reduction_node (n )) return false;
635- me_expr * arg = (me_expr * )n -> parameters [0 ];
636- if (!arg ) return false;
637- if (n -> function == (void * )min_reduce || n -> function == (void * )max_reduce ) {
638- if (arg -> dtype == ME_COMPLEX64 || arg -> dtype == ME_COMPLEX128 ) {
639- return false;
637+ if (!n ) return true;
638+ if (is_reduction_node (n )) {
639+ me_expr * arg = (me_expr * )n -> parameters [0 ];
640+ if (!arg ) return false;
641+ if (contains_reduction (arg )) return false;
642+ me_dtype arg_type = infer_output_type (arg );
643+ if (n -> function == (void * )min_reduce || n -> function == (void * )max_reduce ) {
644+ if (arg_type == ME_COMPLEX64 || arg_type == ME_COMPLEX128 ) {
645+ return false;
646+ }
640647 }
648+ return true;
649+ }
650+
651+ switch (TYPE_MASK (n -> type )) {
652+ case ME_FUNCTION0 :
653+ case ME_FUNCTION1 :
654+ case ME_FUNCTION2 :
655+ case ME_FUNCTION3 :
656+ case ME_FUNCTION4 :
657+ case ME_FUNCTION5 :
658+ case ME_FUNCTION6 :
659+ case ME_FUNCTION7 :
660+ case ME_CLOSURE0 :
661+ case ME_CLOSURE1 :
662+ case ME_CLOSURE2 :
663+ case ME_CLOSURE3 :
664+ case ME_CLOSURE4 :
665+ case ME_CLOSURE5 :
666+ case ME_CLOSURE6 :
667+ case ME_CLOSURE7 :
668+ {
669+ const int arity = ARITY (n -> type );
670+ for (int i = 0 ; i < arity ; i ++ ) {
671+ if (!reduction_usage_is_valid ((const me_expr * )n -> parameters [i ])) {
672+ return false;
673+ }
674+ }
675+ return true;
676+ }
677+ default :
678+ return true;
641679 }
642- return TYPE_MASK (arg -> type ) == ME_VARIABLE || TYPE_MASK (arg -> type ) == ME_CONSTANT ;
643680}
644681
645682/* Infer computation type from expression tree (for evaluation) */
@@ -3698,7 +3735,12 @@ typedef float (*me_fun1_f32)(float);
36983735 SQRT_FUNC , SIN_FUNC , COS_FUNC , EXP_FUNC , LOG_FUNC , FABS_FUNC , POW_FUNC , \
36993736 VEC_CONJ ) \
37003737static void me_eval_##SUFFIX(const me_expr *n) { \
3701- if (!n || !n->output || n->nitems <= 0) return; \
3738+ if (!n || !n->output) return; \
3739+ if (is_reduction_node(n)) { \
3740+ eval_reduction(n, n->nitems); \
3741+ return; \
3742+ } \
3743+ if (n->nitems <= 0) return; \
37023744 \
37033745 int i, j; \
37043746 const int arity = ARITY(n->type); \
@@ -4364,14 +4406,139 @@ static bool all_variables_match_type(const me_expr* n, me_dtype target_type) {
43644406 return true;
43654407}
43664408
4367- static void eval_reduction (const me_expr * n ) {
4409+ static void broadcast_reduction_output (void * output , me_dtype dtype , int output_nitems ) {
4410+ if (!output || output_nitems <= 1 ) return ;
4411+ switch (dtype ) {
4412+ case ME_BOOL :
4413+ {
4414+ bool val = ((bool * )output )[0 ];
4415+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4416+ ((bool * )output )[i ] = val ;
4417+ }
4418+ break ;
4419+ }
4420+ case ME_INT8 :
4421+ {
4422+ int8_t val = ((int8_t * )output )[0 ];
4423+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4424+ ((int8_t * )output )[i ] = val ;
4425+ }
4426+ break ;
4427+ }
4428+ case ME_INT16 :
4429+ {
4430+ int16_t val = ((int16_t * )output )[0 ];
4431+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4432+ ((int16_t * )output )[i ] = val ;
4433+ }
4434+ break ;
4435+ }
4436+ case ME_INT32 :
4437+ {
4438+ int32_t val = ((int32_t * )output )[0 ];
4439+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4440+ ((int32_t * )output )[i ] = val ;
4441+ }
4442+ break ;
4443+ }
4444+ case ME_INT64 :
4445+ {
4446+ int64_t val = ((int64_t * )output )[0 ];
4447+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4448+ ((int64_t * )output )[i ] = val ;
4449+ }
4450+ break ;
4451+ }
4452+ case ME_UINT8 :
4453+ {
4454+ uint8_t val = ((uint8_t * )output )[0 ];
4455+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4456+ ((uint8_t * )output )[i ] = val ;
4457+ }
4458+ break ;
4459+ }
4460+ case ME_UINT16 :
4461+ {
4462+ uint16_t val = ((uint16_t * )output )[0 ];
4463+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4464+ ((uint16_t * )output )[i ] = val ;
4465+ }
4466+ break ;
4467+ }
4468+ case ME_UINT32 :
4469+ {
4470+ uint32_t val = ((uint32_t * )output )[0 ];
4471+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4472+ ((uint32_t * )output )[i ] = val ;
4473+ }
4474+ break ;
4475+ }
4476+ case ME_UINT64 :
4477+ {
4478+ uint64_t val = ((uint64_t * )output )[0 ];
4479+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4480+ ((uint64_t * )output )[i ] = val ;
4481+ }
4482+ break ;
4483+ }
4484+ case ME_FLOAT32 :
4485+ {
4486+ float val = ((float * )output )[0 ];
4487+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4488+ ((float * )output )[i ] = val ;
4489+ }
4490+ break ;
4491+ }
4492+ case ME_FLOAT64 :
4493+ {
4494+ double val = ((double * )output )[0 ];
4495+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4496+ ((double * )output )[i ] = val ;
4497+ }
4498+ break ;
4499+ }
4500+ case ME_COMPLEX64 :
4501+ {
4502+ float _Complex val = ((float _Complex * )output )[0 ];
4503+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4504+ ((float _Complex * )output )[i ] = val ;
4505+ }
4506+ break ;
4507+ }
4508+ case ME_COMPLEX128 :
4509+ {
4510+ double _Complex val = ((double _Complex * )output )[0 ];
4511+ for (int i = 1 ; i < output_nitems ; i ++ ) {
4512+ ((double _Complex * )output )[i ] = val ;
4513+ }
4514+ break ;
4515+ }
4516+ default :
4517+ break ;
4518+ }
4519+ }
4520+
4521+ static void eval_reduction (const me_expr * n , int output_nitems ) {
43684522 if (!n || !n -> output || !is_reduction_node (n )) return ;
4523+ if (output_nitems <= 0 ) return ;
43694524
43704525 me_expr * arg = (me_expr * )n -> parameters [0 ];
43714526 if (!arg ) return ;
43724527
43734528 const int nitems = n -> nitems ;
43744529 me_dtype arg_type = arg -> dtype ;
4530+ if (arg -> type != ME_CONSTANT && arg -> type != ME_VARIABLE ) {
4531+ arg_type = infer_output_type (arg );
4532+ if (nitems > 0 ) {
4533+ if (!arg -> output ) {
4534+ arg -> output = malloc ((size_t )nitems * dtype_size (arg_type ));
4535+ if (!arg -> output ) return ;
4536+ }
4537+ arg -> nitems = nitems ;
4538+ arg -> dtype = arg_type ;
4539+ private_eval (arg );
4540+ }
4541+ }
43754542 me_dtype result_type = reduction_output_dtype (arg_type , n -> function );
43764543 me_dtype output_type = n -> dtype ;
43774544 bool is_prod = n -> function == (void * )prod_reduce ;
@@ -4383,7 +4550,7 @@ static void eval_reduction(const me_expr* n) {
43834550 void * write_ptr = n -> output ;
43844551 void * temp_output = NULL ;
43854552 if (output_type != result_type ) {
4386- temp_output = malloc (dtype_size (result_type ));
4553+ temp_output = malloc (( size_t ) output_nitems * dtype_size (result_type ));
43874554 if (!temp_output ) return ;
43884555 write_ptr = temp_output ;
43894556 }
@@ -4627,7 +4794,13 @@ static void eval_reduction(const me_expr* n) {
46274794 }
46284795 }
46294796 }
4630- else if (arg -> type == ME_VARIABLE ) {
4797+ else {
4798+ const void * saved_bound = arg -> bound ;
4799+ int saved_type = arg -> type ;
4800+ if (arg -> type != ME_VARIABLE ) {
4801+ ((me_expr * )arg )-> bound = arg -> output ;
4802+ ((me_expr * )arg )-> type = ME_VARIABLE ;
4803+ }
46314804 switch (arg_type ) {
46324805 case ME_BOOL :
46334806 {
@@ -5140,12 +5313,21 @@ static void eval_reduction(const me_expr* n) {
51405313 default :
51415314 break ;
51425315 }
5316+ if (saved_type != ME_VARIABLE ) {
5317+ ((me_expr * )arg )-> bound = saved_bound ;
5318+ ((me_expr * )arg )-> type = saved_type ;
5319+ }
5320+ }
5321+
5322+ {
5323+ me_dtype write_type = temp_output ? result_type : output_type ;
5324+ broadcast_reduction_output (write_ptr , write_type , output_nitems );
51435325 }
51445326
51455327 if (temp_output ) {
51465328 convert_func_t conv = get_convert_func (result_type , output_type );
51475329 if (conv ) {
5148- conv (temp_output , n -> output , 1 );
5330+ conv (temp_output , n -> output , output_nitems );
51495331 }
51505332 free (temp_output );
51515333 }
@@ -5155,7 +5337,7 @@ static void private_eval(const me_expr* n) {
51555337 if (!n ) return ;
51565338
51575339 if (is_reduction_node (n )) {
5158- eval_reduction (n );
5340+ eval_reduction (n , 1 );
51595341 return ;
51605342 }
51615343
0 commit comments