@@ -378,93 +378,137 @@ TEST_F(MslASTPrinterTest, WorkgroupMatrix_Multiples) {
378378 EXPECT_EQ (gen.Result (), R"( #include <metal_stdlib>
379379
380380using namespace metal;
381+
382+ template<typename T, size_t N>
383+ struct tint_array {
384+ const constant T& operator[](size_t i) const constant { return elements[i]; }
385+ device T& operator[](size_t i) device { return elements[i]; }
386+ const device T& operator[](size_t i) const device { return elements[i]; }
387+ thread T& operator[](size_t i) thread { return elements[i]; }
388+ const thread T& operator[](size_t i) const thread { return elements[i]; }
389+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
390+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
391+ T elements[N];
392+ };
393+
381394struct tint_symbol_16 {
382395 float2x2 m1;
383- float2x3 m2;
384396 float2x4 m3;
385397};
386398
387399struct tint_symbol_24 {
388400 float3x2 m4;
389- float3x3 m5;
390401 float3x4 m6;
391402};
392403
393404struct tint_symbol_32 {
394405 float4x2 m7;
395- float4x3 m8;
396406 float4x4 m9;
397407};
398408
399- void tint_zero_workgroup_memory(uint local_idx, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
409+ struct tint_packed_vec3_f32_array_element {
410+ packed_float3 elements;
411+ };
412+
413+ float2x3 tint_unpack_vec3_in_composite(tint_array<tint_packed_vec3_f32_array_element, 2> in) {
414+ float2x3 result = float2x3(float3(in[0].elements), float3(in[1].elements));
415+ return result;
416+ }
417+
418+ float3x3 tint_unpack_vec3_in_composite_1(tint_array<tint_packed_vec3_f32_array_element, 3> in) {
419+ float3x3 result = float3x3(float3(in[0].elements), float3(in[1].elements), float3(in[2].elements));
420+ return result;
421+ }
422+
423+ float4x3 tint_unpack_vec3_in_composite_2(tint_array<tint_packed_vec3_f32_array_element, 4> in) {
424+ float4x3 result = float4x3(float3(in[0].elements), float3(in[1].elements), float3(in[2].elements), float3(in[3].elements));
425+ return result;
426+ }
427+
428+ tint_array<tint_packed_vec3_f32_array_element, 2> tint_pack_vec3_in_composite(float2x3 in) {
429+ tint_array<tint_packed_vec3_f32_array_element, 2> result = tint_array<tint_packed_vec3_f32_array_element, 2>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}};
430+ return result;
431+ }
432+
433+ tint_array<tint_packed_vec3_f32_array_element, 3> tint_pack_vec3_in_composite_1(float3x3 in) {
434+ tint_array<tint_packed_vec3_f32_array_element, 3> result = tint_array<tint_packed_vec3_f32_array_element, 3>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}, {.elements=packed_float3(in[2])}};
435+ return result;
436+ }
437+
438+ tint_array<tint_packed_vec3_f32_array_element, 4> tint_pack_vec3_in_composite_2(float4x3 in) {
439+ tint_array<tint_packed_vec3_f32_array_element, 4> result = tint_array<tint_packed_vec3_f32_array_element, 4>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}, {.elements=packed_float3(in[2])}, {.elements=packed_float3(in[3])}};
440+ return result;
441+ }
442+
443+ void tint_zero_workgroup_memory(uint local_idx, threadgroup float2x2* const tint_symbol, threadgroup tint_array<tint_packed_vec3_f32_array_element, 2>* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
400444 if ((local_idx < 1u)) {
401445 *(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
402- *(tint_symbol_1) = float2x3(float3(0.0f), float3(0.0f));
446+ *(tint_symbol_1) = tint_pack_vec3_in_composite( float2x3(float3(0.0f), float3(0.0f) ));
403447 *(tint_symbol_2) = float2x4(float4(0.0f), float4(0.0f));
404448 }
405449 threadgroup_barrier(mem_flags::mem_threadgroup);
406450}
407451
408- void tint_zero_workgroup_memory_1(uint local_idx_1, threadgroup float3x2* const tint_symbol_3, threadgroup float3x3 * const tint_symbol_4, threadgroup float3x4* const tint_symbol_5) {
452+ void tint_zero_workgroup_memory_1(uint local_idx_1, threadgroup float3x2* const tint_symbol_3, threadgroup tint_array<tint_packed_vec3_f32_array_element, 3> * const tint_symbol_4, threadgroup float3x4* const tint_symbol_5) {
409453 if ((local_idx_1 < 1u)) {
410454 *(tint_symbol_3) = float3x2(float2(0.0f), float2(0.0f), float2(0.0f));
411- *(tint_symbol_4) = float3x3(float3(0.0f), float3(0.0f), float3(0.0f));
455+ *(tint_symbol_4) = tint_pack_vec3_in_composite_1( float3x3(float3(0.0f), float3(0.0f), float3(0.0f) ));
412456 *(tint_symbol_5) = float3x4(float4(0.0f), float4(0.0f), float4(0.0f));
413457 }
414458 threadgroup_barrier(mem_flags::mem_threadgroup);
415459}
416460
417- void tint_zero_workgroup_memory_2(uint local_idx_2, threadgroup float4x2* const tint_symbol_6, threadgroup float4x3 * const tint_symbol_7, threadgroup float4x4* const tint_symbol_8) {
461+ void tint_zero_workgroup_memory_2(uint local_idx_2, threadgroup float4x2* const tint_symbol_6, threadgroup tint_array<tint_packed_vec3_f32_array_element, 4> * const tint_symbol_7, threadgroup float4x4* const tint_symbol_8) {
418462 if ((local_idx_2 < 1u)) {
419463 *(tint_symbol_6) = float4x2(float2(0.0f), float2(0.0f), float2(0.0f), float2(0.0f));
420- *(tint_symbol_7) = float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f));
464+ *(tint_symbol_7) = tint_pack_vec3_in_composite_2( float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f) ));
421465 *(tint_symbol_8) = float4x4(float4(0.0f), float4(0.0f), float4(0.0f), float4(0.0f));
422466 }
423467 threadgroup_barrier(mem_flags::mem_threadgroup);
424468}
425469
426- void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol_9, threadgroup float2x3 * const tint_symbol_10, threadgroup float2x4* const tint_symbol_11) {
470+ void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol_9, threadgroup tint_array<tint_packed_vec3_f32_array_element, 2> * const tint_symbol_10, threadgroup float2x4* const tint_symbol_11) {
427471 tint_zero_workgroup_memory(local_invocation_index, tint_symbol_9, tint_symbol_10, tint_symbol_11);
428472 float2x2 const a1 = *(tint_symbol_9);
429- float2x3 const a2 = *(tint_symbol_10);
473+ float2x3 const a2 = tint_unpack_vec3_in_composite( *(tint_symbol_10) );
430474 float2x4 const a3 = *(tint_symbol_11);
431475}
432476
433477kernel void main1(threadgroup tint_symbol_16* tint_symbol_13 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
434478 threadgroup float2x2* const tint_symbol_12 = &((*(tint_symbol_13)).m1);
435- threadgroup float2x3* const tint_symbol_14 = &((*(tint_symbol_13)).m2) ;
479+ threadgroup tint_array<tint_packed_vec3_f32_array_element, 2> tint_symbol_14;
436480 threadgroup float2x4* const tint_symbol_15 = &((*(tint_symbol_13)).m3);
437- main1_inner(local_invocation_index, tint_symbol_12, tint_symbol_14, tint_symbol_15);
481+ main1_inner(local_invocation_index, tint_symbol_12, &( tint_symbol_14) , tint_symbol_15);
438482 return;
439483}
440484
441- void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_17, threadgroup float3x3 * const tint_symbol_18, threadgroup float3x4* const tint_symbol_19) {
485+ void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_17, threadgroup tint_array<tint_packed_vec3_f32_array_element, 3> * const tint_symbol_18, threadgroup float3x4* const tint_symbol_19) {
442486 tint_zero_workgroup_memory_1(local_invocation_index_1, tint_symbol_17, tint_symbol_18, tint_symbol_19);
443487 float3x2 const a1 = *(tint_symbol_17);
444- float3x3 const a2 = *(tint_symbol_18);
488+ float3x3 const a2 = tint_unpack_vec3_in_composite_1( *(tint_symbol_18) );
445489 float3x4 const a3 = *(tint_symbol_19);
446490}
447491
448492kernel void main2(threadgroup tint_symbol_24* tint_symbol_21 [[threadgroup(0)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
449493 threadgroup float3x2* const tint_symbol_20 = &((*(tint_symbol_21)).m4);
450- threadgroup float3x3* const tint_symbol_22 = &((*(tint_symbol_21)).m5) ;
494+ threadgroup tint_array<tint_packed_vec3_f32_array_element, 3> tint_symbol_22;
451495 threadgroup float3x4* const tint_symbol_23 = &((*(tint_symbol_21)).m6);
452- main2_inner(local_invocation_index_1, tint_symbol_20, tint_symbol_22, tint_symbol_23);
496+ main2_inner(local_invocation_index_1, tint_symbol_20, &( tint_symbol_22) , tint_symbol_23);
453497 return;
454498}
455499
456- void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_25, threadgroup float4x3 * const tint_symbol_26, threadgroup float4x4* const tint_symbol_27) {
500+ void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_25, threadgroup tint_array<tint_packed_vec3_f32_array_element, 4> * const tint_symbol_26, threadgroup float4x4* const tint_symbol_27) {
457501 tint_zero_workgroup_memory_2(local_invocation_index_2, tint_symbol_25, tint_symbol_26, tint_symbol_27);
458502 float4x2 const a1 = *(tint_symbol_25);
459- float4x3 const a2 = *(tint_symbol_26);
503+ float4x3 const a2 = tint_unpack_vec3_in_composite_2( *(tint_symbol_26) );
460504 float4x4 const a3 = *(tint_symbol_27);
461505}
462506
463507kernel void main3(threadgroup tint_symbol_32* tint_symbol_29 [[threadgroup(0)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
464508 threadgroup float4x2* const tint_symbol_28 = &((*(tint_symbol_29)).m7);
465- threadgroup float4x3* const tint_symbol_30 = &((*(tint_symbol_29)).m8) ;
509+ threadgroup tint_array<tint_packed_vec3_f32_array_element, 4> tint_symbol_30;
466510 threadgroup float4x4* const tint_symbol_31 = &((*(tint_symbol_29)).m9);
467- main3_inner(local_invocation_index_2, tint_symbol_28, tint_symbol_30, tint_symbol_31);
511+ main3_inner(local_invocation_index_2, tint_symbol_28, &( tint_symbol_30) , tint_symbol_31);
468512 return;
469513}
470514
@@ -479,11 +523,11 @@ kernel void main4_no_usages() {
479523 ASSERT_TRUE (allocations.count (" main2" ));
480524 ASSERT_TRUE (allocations.count (" main3" ));
481525 ASSERT_EQ (allocations.at (" main1" ).size (), 1u );
482- EXPECT_EQ (allocations.at (" main1" )[0 ], 20u * sizeof (float ));
526+ EXPECT_EQ (allocations.at (" main1" )[0 ], 12u * sizeof (float ));
483527 ASSERT_EQ (allocations.at (" main2" ).size (), 1u );
484- EXPECT_EQ (allocations.at (" main2" )[0 ], 32u * sizeof (float ));
528+ EXPECT_EQ (allocations.at (" main2" )[0 ], 20u * sizeof (float ));
485529 ASSERT_EQ (allocations.at (" main3" ).size (), 1u );
486- EXPECT_EQ (allocations.at (" main3" )[0 ], 40u * sizeof (float ));
530+ EXPECT_EQ (allocations.at (" main3" )[0 ], 24u * sizeof (float ));
487531 EXPECT_EQ (allocations.at (" main4_no_usages" ).size (), 0u );
488532}
489533
0 commit comments