jax.errors.ConcretizationTypeError in Splash attention during model.generate #21916

@pctablet505

Description

When calling model.generate, a jax.errors.ConcretizationTypeError is raised while the Splash attention kernel is being built (in make_splash_mha). Keras then falls back to the native JAX dot_product_attention.

The fallback works, but the error prints a long traceback on every call to model.generate. This fills the log files with noise and makes them hard to inspect.

This appears to be a JAX tracing issue: the attention mask's __hash__ method calls tobytes() on the mask array, which requires a concrete value, but at that point the array is an abstract tracer.
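The failure mode can be reproduced in isolation. The following is a minimal sketch, not the Keras or Splash attention code: `hash_mask` and `probe` are hypothetical helpers that only imitate the `hash((type(self), self.array.tobytes()))` pattern from the traceback, and `jax.jit` stands in for the `while_body` tracing context reported in the error.

```python
import jax
import jax.numpy as jnp
import numpy as np


def hash_mask(mask):
    # Hypothetical helper imitating the hashing pattern in the traceback:
    # the hash is derived from the raw bytes of the mask array.
    return hash((type(mask), mask.tobytes()))


def probe(mask):
    # Report whether the mask could be hashed or hit the tracer error.
    try:
        hash_mask(mask)
        return "hashed"
    except jax.errors.ConcretizationTypeError:
        return "ConcretizationTypeError"


# A concrete NumPy array exposes its bytes, so hashing succeeds.
concrete_result = probe(np.ones((1, 250), dtype=bool))

traced_result = []


@jax.jit
def step(mask):
    # Inside jit, `mask` is an abstract tracer with no concrete bytes,
    # so tobytes() raises at trace time.
    traced_result.append(probe(mask))
    return mask


step(jnp.ones((1, 250), dtype=bool))

print(concrete_result)  # hashing works on a concrete array
print(traced_result)    # hashing fails on a traced array
```

If this reading is right, any mask built from a value that is traced at the call site (here the `bool[1,250]` array that depends on `args[3]`) cannot be hashed this way.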

Traceback

The fallback is triggered from this file:
keras/src/backend/jax/nn.py

Failed to apply Splash kernel for flash attention. Falling back to JAX native dot_product_attention.
Traceback (most recent call last):
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/keras/src/backend/jax/nn.py", line 1344, in dot_product_attention
    output = wrap_flash_attention(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/keras/src/backend/jax/nn.py", line 1196, in wrap_flash_attention
    splash_kernel = splash_attention_kernel.make_splash_mha(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2557, in _make_splash_attention
    fwd_mask_info, mask_function_fwd = process_mask_fn(
                                       ^^^^^^^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 222, in __hash__
    return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 222, in <genexpr>
    return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
                                      ^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 516, in __hash__
    return hash((type(self), self.array.tobytes()))
                             ^^^^^^^^^^^^^^^^^^^^
  File "/export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/jax/_src/core.py", line 953, in tobytes
    raise ConcretizationTypeError(self,
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[1,250]
The tobytes() method was called on traced array with shape bool[1,250].
The error occurred while tracing the function _body at /export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/keras/src/backend/jax/core.py:474 for while_body. This concrete value was not available in Python because it depends on the value of the argument args[3].
The error occurred while tracing the function _body at /export/hda3/borglet/local_ram_fs_dirs/0.hellorahul_group_212967041.1.benchmarking_benchmark.hellorahul.304686816269.8f4f9f98cb7a3055/mpm_7141124cf57a55fcacd0972eb8bff5af_c73bae2f6e404030bf1737ea293b9ebf/benchmark.runfiles/google3/third_party/py/keras/src/backend/jax/core.py:474 for while_body. This concrete value was not available in Python because it depends on the value of the argument args[3].

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError
