Skip to content

Commit d7b0f21

Browse files
committed
Allow miniexpr to take care of reductions with multi-operand expressions
1 parent 5ceba21 commit d7b0f21

File tree

4 files changed

+245
-19
lines changed

4 files changed

+245
-19
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from time import time
2+
import blosc2
3+
import numpy as np
4+
import numexpr as ne
5+
6+
N = 10_000
7+
dtype= np.float32
8+
cparams = blosc2.CParams(codec=blosc2.Codec.BLOSCLZ, clevel=1)
9+
10+
t0 = time()
11+
#a = blosc2.ones((N, N), dtype=dtype, cparams=cparams)
12+
#a = blosc2.arange(np.prod((N, N)), shape=(N, N), dtype=dtype, cparams=cparams)
13+
a = blosc2.linspace(0., 1., np.prod((N, N)), shape=(N, N), dtype=dtype, cparams=cparams)
14+
#rng = np.random.default_rng(1234)
15+
#a = rng.integers(0, 2, size=(N, N), dtype=dtype)
16+
#a = blosc2.asarray(a, cparams=cparams, urlpath="a.b2nd", mode="w")
17+
print(f"Time to create data: {(time() - t0) * 1000 :.4f} ms")
18+
#print(a[:])
19+
t0 = time()
20+
b = a.copy()
21+
c = a.copy()
22+
print(f"Time to copy data: {(time() - t0) * 1000 :.4f} ms")
23+
24+
t0 = time()
25+
res = blosc2.sum(a + b + c, cparams=cparams)
26+
print(f"Time to evaluate: {(time() - t0) * 1000 :.4f} ms")
27+
print("Result:", res, "Mean:", res / (N * N))
28+
29+
na = a[:]
30+
nb = b[:]
31+
nc = c[:]
32+
#np.testing.assert_allclose(res, np.sum(na + nb + nc))
33+
#
34+
#t0 = time()
35+
#res = ne.evaluate("sum(na)")
36+
#print(f"Time to evaluate with NumExpr: {(time() - t0) * 1000 :.4f} ms")
37+
38+
t0 = time()
39+
res = np.sum(na + nb + nc)
40+
print(f"Time to evaluate with NumPy: {(time() - t0) * 1000 :.4f} ms")

bench/ndarray/expr-reduction-sum.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@
3131
na = a[:]
3232
nb = b[:]
3333
nc = c[:]
34-
np.testing.assert_allclose(res, np.sum(na), rtol=1e-5)
34+
# np.testing.assert_allclose(res, np.sum(na), rtol=1e-5)
3535

3636
t0 = time()
37-
res = ne.evaluate("sum(na)")
37+
res = np.sum(na)
3838
t = time() - t0
39-
print(f"Time to evaluate with NumExpr: {t * 1000 :.4f} ms")
39+
print(f"Time to evaluate with NumPy: {t * 1000 :.4f} ms")
4040
print(f"Speed (GB/s): {(na.nbytes / 1e9) / t:.2f}")
4141

4242
t0 = time()
43-
res = np.sum(na)
43+
res = ne.evaluate("sum(na)")
4444
t = time() - t0
45-
print(f"Time to evaluate with NumPy: {t * 1000 :.4f} ms")
45+
print(f"Time to evaluate with NumExpr: {t * 1000 :.4f} ms")
4646
print(f"Speed (GB/s): {(na.nbytes / 1e9) / t:.2f}")

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1983,7 +1983,11 @@ def reduce_slices( # noqa: C901
19831983
chunks = temp.chunks
19841984
del temp
19851985

1986-
if (where is None and fast_path and all_ndarray) and (expression == "o0" or expression == "(o0)"):
1986+
# if (where is None and fast_path and all_ndarray) and (expression == "o0" or expression == "(o0)"):
1987+
# miniexpr does not shine specially for single operand reductions
1988+
if (where is None and fast_path and all_ndarray) and not (
1989+
expression == "o0" or expression == "(o0)"
1990+
): # or 1: # XXX make tests pass
19871991
# Only this case is supported so far
19881992
if use_miniexpr:
19891993
for op in operands.values():

src/blosc2/miniexpr.c

Lines changed: 195 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,9 @@ typedef struct state {
569569

570570
/* Forward declarations */
571571
static me_expr* new_expr(const int type, const me_expr* parameters[]);
572+
static me_dtype infer_output_type(const me_expr* n);
573+
static void private_eval(const me_expr* n);
574+
static void eval_reduction(const me_expr* n, int output_nitems);
572575
static double conj_wrapper(double x);
573576
static double imag_wrapper(double x);
574577
static double real_wrapper(double x);
@@ -631,15 +634,49 @@ static bool contains_reduction(const me_expr* n) {
631634
}
632635

633636
static bool reduction_usage_is_valid(const me_expr* n) {
634-
if (!is_reduction_node(n)) return false;
635-
me_expr* arg = (me_expr*)n->parameters[0];
636-
if (!arg) return false;
637-
if (n->function == (void*)min_reduce || n->function == (void*)max_reduce) {
638-
if (arg->dtype == ME_COMPLEX64 || arg->dtype == ME_COMPLEX128) {
639-
return false;
637+
if (!n) return true;
638+
if (is_reduction_node(n)) {
639+
me_expr* arg = (me_expr*)n->parameters[0];
640+
if (!arg) return false;
641+
if (contains_reduction(arg)) return false;
642+
me_dtype arg_type = infer_output_type(arg);
643+
if (n->function == (void*)min_reduce || n->function == (void*)max_reduce) {
644+
if (arg_type == ME_COMPLEX64 || arg_type == ME_COMPLEX128) {
645+
return false;
646+
}
640647
}
648+
return true;
649+
}
650+
651+
switch (TYPE_MASK(n->type)) {
652+
case ME_FUNCTION0:
653+
case ME_FUNCTION1:
654+
case ME_FUNCTION2:
655+
case ME_FUNCTION3:
656+
case ME_FUNCTION4:
657+
case ME_FUNCTION5:
658+
case ME_FUNCTION6:
659+
case ME_FUNCTION7:
660+
case ME_CLOSURE0:
661+
case ME_CLOSURE1:
662+
case ME_CLOSURE2:
663+
case ME_CLOSURE3:
664+
case ME_CLOSURE4:
665+
case ME_CLOSURE5:
666+
case ME_CLOSURE6:
667+
case ME_CLOSURE7:
668+
{
669+
const int arity = ARITY(n->type);
670+
for (int i = 0; i < arity; i++) {
671+
if (!reduction_usage_is_valid((const me_expr*)n->parameters[i])) {
672+
return false;
673+
}
674+
}
675+
return true;
676+
}
677+
default:
678+
return true;
641679
}
642-
return TYPE_MASK(arg->type) == ME_VARIABLE || TYPE_MASK(arg->type) == ME_CONSTANT;
643680
}
644681

645682
/* Infer computation type from expression tree (for evaluation) */
@@ -3698,7 +3735,12 @@ typedef float (*me_fun1_f32)(float);
36983735
SQRT_FUNC, SIN_FUNC, COS_FUNC, EXP_FUNC, LOG_FUNC, FABS_FUNC, POW_FUNC, \
36993736
VEC_CONJ) \
37003737
static void me_eval_##SUFFIX(const me_expr *n) { \
3701-
if (!n || !n->output || n->nitems <= 0) return; \
3738+
if (!n || !n->output) return; \
3739+
if (is_reduction_node(n)) { \
3740+
eval_reduction(n, n->nitems); \
3741+
return; \
3742+
} \
3743+
if (n->nitems <= 0) return; \
37023744
\
37033745
int i, j; \
37043746
const int arity = ARITY(n->type); \
@@ -4364,14 +4406,139 @@ static bool all_variables_match_type(const me_expr* n, me_dtype target_type) {
43644406
return true;
43654407
}
43664408

4367-
static void eval_reduction(const me_expr* n) {
4409+
static void broadcast_reduction_output(void* output, me_dtype dtype, int output_nitems) {
4410+
if (!output || output_nitems <= 1) return;
4411+
switch (dtype) {
4412+
case ME_BOOL:
4413+
{
4414+
bool val = ((bool*)output)[0];
4415+
for (int i = 1; i < output_nitems; i++) {
4416+
((bool*)output)[i] = val;
4417+
}
4418+
break;
4419+
}
4420+
case ME_INT8:
4421+
{
4422+
int8_t val = ((int8_t*)output)[0];
4423+
for (int i = 1; i < output_nitems; i++) {
4424+
((int8_t*)output)[i] = val;
4425+
}
4426+
break;
4427+
}
4428+
case ME_INT16:
4429+
{
4430+
int16_t val = ((int16_t*)output)[0];
4431+
for (int i = 1; i < output_nitems; i++) {
4432+
((int16_t*)output)[i] = val;
4433+
}
4434+
break;
4435+
}
4436+
case ME_INT32:
4437+
{
4438+
int32_t val = ((int32_t*)output)[0];
4439+
for (int i = 1; i < output_nitems; i++) {
4440+
((int32_t*)output)[i] = val;
4441+
}
4442+
break;
4443+
}
4444+
case ME_INT64:
4445+
{
4446+
int64_t val = ((int64_t*)output)[0];
4447+
for (int i = 1; i < output_nitems; i++) {
4448+
((int64_t*)output)[i] = val;
4449+
}
4450+
break;
4451+
}
4452+
case ME_UINT8:
4453+
{
4454+
uint8_t val = ((uint8_t*)output)[0];
4455+
for (int i = 1; i < output_nitems; i++) {
4456+
((uint8_t*)output)[i] = val;
4457+
}
4458+
break;
4459+
}
4460+
case ME_UINT16:
4461+
{
4462+
uint16_t val = ((uint16_t*)output)[0];
4463+
for (int i = 1; i < output_nitems; i++) {
4464+
((uint16_t*)output)[i] = val;
4465+
}
4466+
break;
4467+
}
4468+
case ME_UINT32:
4469+
{
4470+
uint32_t val = ((uint32_t*)output)[0];
4471+
for (int i = 1; i < output_nitems; i++) {
4472+
((uint32_t*)output)[i] = val;
4473+
}
4474+
break;
4475+
}
4476+
case ME_UINT64:
4477+
{
4478+
uint64_t val = ((uint64_t*)output)[0];
4479+
for (int i = 1; i < output_nitems; i++) {
4480+
((uint64_t*)output)[i] = val;
4481+
}
4482+
break;
4483+
}
4484+
case ME_FLOAT32:
4485+
{
4486+
float val = ((float*)output)[0];
4487+
for (int i = 1; i < output_nitems; i++) {
4488+
((float*)output)[i] = val;
4489+
}
4490+
break;
4491+
}
4492+
case ME_FLOAT64:
4493+
{
4494+
double val = ((double*)output)[0];
4495+
for (int i = 1; i < output_nitems; i++) {
4496+
((double*)output)[i] = val;
4497+
}
4498+
break;
4499+
}
4500+
case ME_COMPLEX64:
4501+
{
4502+
float _Complex val = ((float _Complex*)output)[0];
4503+
for (int i = 1; i < output_nitems; i++) {
4504+
((float _Complex*)output)[i] = val;
4505+
}
4506+
break;
4507+
}
4508+
case ME_COMPLEX128:
4509+
{
4510+
double _Complex val = ((double _Complex*)output)[0];
4511+
for (int i = 1; i < output_nitems; i++) {
4512+
((double _Complex*)output)[i] = val;
4513+
}
4514+
break;
4515+
}
4516+
default:
4517+
break;
4518+
}
4519+
}
4520+
4521+
static void eval_reduction(const me_expr* n, int output_nitems) {
43684522
if (!n || !n->output || !is_reduction_node(n)) return;
4523+
if (output_nitems <= 0) return;
43694524

43704525
me_expr* arg = (me_expr*)n->parameters[0];
43714526
if (!arg) return;
43724527

43734528
const int nitems = n->nitems;
43744529
me_dtype arg_type = arg->dtype;
4530+
if (arg->type != ME_CONSTANT && arg->type != ME_VARIABLE) {
4531+
arg_type = infer_output_type(arg);
4532+
if (nitems > 0) {
4533+
if (!arg->output) {
4534+
arg->output = malloc((size_t)nitems * dtype_size(arg_type));
4535+
if (!arg->output) return;
4536+
}
4537+
arg->nitems = nitems;
4538+
arg->dtype = arg_type;
4539+
private_eval(arg);
4540+
}
4541+
}
43754542
me_dtype result_type = reduction_output_dtype(arg_type, n->function);
43764543
me_dtype output_type = n->dtype;
43774544
bool is_prod = n->function == (void*)prod_reduce;
@@ -4383,7 +4550,7 @@ static void eval_reduction(const me_expr* n) {
43834550
void* write_ptr = n->output;
43844551
void* temp_output = NULL;
43854552
if (output_type != result_type) {
4386-
temp_output = malloc(dtype_size(result_type));
4553+
temp_output = malloc((size_t)output_nitems * dtype_size(result_type));
43874554
if (!temp_output) return;
43884555
write_ptr = temp_output;
43894556
}
@@ -4627,7 +4794,13 @@ static void eval_reduction(const me_expr* n) {
46274794
}
46284795
}
46294796
}
4630-
else if (arg->type == ME_VARIABLE) {
4797+
else {
4798+
const void* saved_bound = arg->bound;
4799+
int saved_type = arg->type;
4800+
if (arg->type != ME_VARIABLE) {
4801+
((me_expr*)arg)->bound = arg->output;
4802+
((me_expr*)arg)->type = ME_VARIABLE;
4803+
}
46314804
switch (arg_type) {
46324805
case ME_BOOL:
46334806
{
@@ -5140,12 +5313,21 @@ static void eval_reduction(const me_expr* n) {
51405313
default:
51415314
break;
51425315
}
5316+
if (saved_type != ME_VARIABLE) {
5317+
((me_expr*)arg)->bound = saved_bound;
5318+
((me_expr*)arg)->type = saved_type;
5319+
}
5320+
}
5321+
5322+
{
5323+
me_dtype write_type = temp_output ? result_type : output_type;
5324+
broadcast_reduction_output(write_ptr, write_type, output_nitems);
51435325
}
51445326

51455327
if (temp_output) {
51465328
convert_func_t conv = get_convert_func(result_type, output_type);
51475329
if (conv) {
5148-
conv(temp_output, n->output, 1);
5330+
conv(temp_output, n->output, output_nitems);
51495331
}
51505332
free(temp_output);
51515333
}
@@ -5155,7 +5337,7 @@ static void private_eval(const me_expr* n) {
51555337
if (!n) return;
51565338

51575339
if (is_reduction_node(n)) {
5158-
eval_reduction(n);
5340+
eval_reduction(n, 1);
51595341
return;
51605342
}
51615343

0 commit comments

Comments
 (0)