Skip to content

Commit 5f68dee

Browse files
jrprice authored and mibrunin committed
[Backport] Security bug 378725734 (2/2)
Use packed_vec3 for workgroup storage

This makes sure that the threadgroup allocation sizes that Tint reflects to Dawn match the sizes of the types used in the generated MSL shader.

Bug: 378725734
Change-Id: Ib67f6d3299e376ca263419245912e8f453b6cb88
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/215075
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
(cherry picked from commit c368b05c475b3473276ad41f09c5f1b149df00e8)
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/215937
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-on: https://codereview.qt-project.org/c/qt/qtwebengine-chromium/+/611748
Reviewed-by: Anu Aliyas <anu.aliyas@qt.io>
1 parent 81f9952 commit 5f68dee

File tree

2 files changed

+82
-30
lines changed

2 files changed

+82
-30
lines changed

chromium/third_party/dawn/src/tint/lang/msl/writer/ast_printer/ast_printer_test.cc

Lines changed: 68 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -378,93 +378,137 @@ TEST_F(MslASTPrinterTest, WorkgroupMatrix_Multiples) {
378378
EXPECT_EQ(gen.Result(), R"(#include <metal_stdlib>
379379
380380
using namespace metal;
381+
382+
template<typename T, size_t N>
383+
struct tint_array {
384+
const constant T& operator[](size_t i) const constant { return elements[i]; }
385+
device T& operator[](size_t i) device { return elements[i]; }
386+
const device T& operator[](size_t i) const device { return elements[i]; }
387+
thread T& operator[](size_t i) thread { return elements[i]; }
388+
const thread T& operator[](size_t i) const thread { return elements[i]; }
389+
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
390+
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
391+
T elements[N];
392+
};
393+
381394
struct tint_symbol_16 {
382395
float2x2 m1;
383-
float2x3 m2;
384396
float2x4 m3;
385397
};
386398
387399
struct tint_symbol_24 {
388400
float3x2 m4;
389-
float3x3 m5;
390401
float3x4 m6;
391402
};
392403
393404
struct tint_symbol_32 {
394405
float4x2 m7;
395-
float4x3 m8;
396406
float4x4 m9;
397407
};
398408
399-
void tint_zero_workgroup_memory(uint local_idx, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
409+
struct tint_packed_vec3_f32_array_element {
410+
packed_float3 elements;
411+
};
412+
413+
float2x3 tint_unpack_vec3_in_composite(tint_array<tint_packed_vec3_f32_array_element, 2> in) {
414+
float2x3 result = float2x3(float3(in[0].elements), float3(in[1].elements));
415+
return result;
416+
}
417+
418+
float3x3 tint_unpack_vec3_in_composite_1(tint_array<tint_packed_vec3_f32_array_element, 3> in) {
419+
float3x3 result = float3x3(float3(in[0].elements), float3(in[1].elements), float3(in[2].elements));
420+
return result;
421+
}
422+
423+
float4x3 tint_unpack_vec3_in_composite_2(tint_array<tint_packed_vec3_f32_array_element, 4> in) {
424+
float4x3 result = float4x3(float3(in[0].elements), float3(in[1].elements), float3(in[2].elements), float3(in[3].elements));
425+
return result;
426+
}
427+
428+
tint_array<tint_packed_vec3_f32_array_element, 2> tint_pack_vec3_in_composite(float2x3 in) {
429+
tint_array<tint_packed_vec3_f32_array_element, 2> result = tint_array<tint_packed_vec3_f32_array_element, 2>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}};
430+
return result;
431+
}
432+
433+
tint_array<tint_packed_vec3_f32_array_element, 3> tint_pack_vec3_in_composite_1(float3x3 in) {
434+
tint_array<tint_packed_vec3_f32_array_element, 3> result = tint_array<tint_packed_vec3_f32_array_element, 3>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}, {.elements=packed_float3(in[2])}};
435+
return result;
436+
}
437+
438+
tint_array<tint_packed_vec3_f32_array_element, 4> tint_pack_vec3_in_composite_2(float4x3 in) {
439+
tint_array<tint_packed_vec3_f32_array_element, 4> result = tint_array<tint_packed_vec3_f32_array_element, 4>{{.elements=packed_float3(in[0])}, {.elements=packed_float3(in[1])}, {.elements=packed_float3(in[2])}, {.elements=packed_float3(in[3])}};
440+
return result;
441+
}
442+
443+
void tint_zero_workgroup_memory(uint local_idx, threadgroup float2x2* const tint_symbol, threadgroup tint_array<tint_packed_vec3_f32_array_element, 2>* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
400444
if ((local_idx < 1u)) {
401445
*(tint_symbol) = float2x2(float2(0.0f), float2(0.0f));
402-
*(tint_symbol_1) = float2x3(float3(0.0f), float3(0.0f));
446+
*(tint_symbol_1) = tint_pack_vec3_in_composite(float2x3(float3(0.0f), float3(0.0f)));
403447
*(tint_symbol_2) = float2x4(float4(0.0f), float4(0.0f));
404448
}
405449
threadgroup_barrier(mem_flags::mem_threadgroup);
406450
}
407451
408-
void tint_zero_workgroup_memory_1(uint local_idx_1, threadgroup float3x2* const tint_symbol_3, threadgroup float3x3* const tint_symbol_4, threadgroup float3x4* const tint_symbol_5) {
452+
void tint_zero_workgroup_memory_1(uint local_idx_1, threadgroup float3x2* const tint_symbol_3, threadgroup tint_array<tint_packed_vec3_f32_array_element, 3>* const tint_symbol_4, threadgroup float3x4* const tint_symbol_5) {
409453
if ((local_idx_1 < 1u)) {
410454
*(tint_symbol_3) = float3x2(float2(0.0f), float2(0.0f), float2(0.0f));
411-
*(tint_symbol_4) = float3x3(float3(0.0f), float3(0.0f), float3(0.0f));
455+
*(tint_symbol_4) = tint_pack_vec3_in_composite_1(float3x3(float3(0.0f), float3(0.0f), float3(0.0f)));
412456
*(tint_symbol_5) = float3x4(float4(0.0f), float4(0.0f), float4(0.0f));
413457
}
414458
threadgroup_barrier(mem_flags::mem_threadgroup);
415459
}
416460
417-
void tint_zero_workgroup_memory_2(uint local_idx_2, threadgroup float4x2* const tint_symbol_6, threadgroup float4x3* const tint_symbol_7, threadgroup float4x4* const tint_symbol_8) {
461+
void tint_zero_workgroup_memory_2(uint local_idx_2, threadgroup float4x2* const tint_symbol_6, threadgroup tint_array<tint_packed_vec3_f32_array_element, 4>* const tint_symbol_7, threadgroup float4x4* const tint_symbol_8) {
418462
if ((local_idx_2 < 1u)) {
419463
*(tint_symbol_6) = float4x2(float2(0.0f), float2(0.0f), float2(0.0f), float2(0.0f));
420-
*(tint_symbol_7) = float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f));
464+
*(tint_symbol_7) = tint_pack_vec3_in_composite_2(float4x3(float3(0.0f), float3(0.0f), float3(0.0f), float3(0.0f)));
421465
*(tint_symbol_8) = float4x4(float4(0.0f), float4(0.0f), float4(0.0f), float4(0.0f));
422466
}
423467
threadgroup_barrier(mem_flags::mem_threadgroup);
424468
}
425469
426-
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol_9, threadgroup float2x3* const tint_symbol_10, threadgroup float2x4* const tint_symbol_11) {
470+
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol_9, threadgroup tint_array<tint_packed_vec3_f32_array_element, 2>* const tint_symbol_10, threadgroup float2x4* const tint_symbol_11) {
427471
tint_zero_workgroup_memory(local_invocation_index, tint_symbol_9, tint_symbol_10, tint_symbol_11);
428472
float2x2 const a1 = *(tint_symbol_9);
429-
float2x3 const a2 = *(tint_symbol_10);
473+
float2x3 const a2 = tint_unpack_vec3_in_composite(*(tint_symbol_10));
430474
float2x4 const a3 = *(tint_symbol_11);
431475
}
432476
433477
kernel void main1(threadgroup tint_symbol_16* tint_symbol_13 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
434478
threadgroup float2x2* const tint_symbol_12 = &((*(tint_symbol_13)).m1);
435-
threadgroup float2x3* const tint_symbol_14 = &((*(tint_symbol_13)).m2);
479+
threadgroup tint_array<tint_packed_vec3_f32_array_element, 2> tint_symbol_14;
436480
threadgroup float2x4* const tint_symbol_15 = &((*(tint_symbol_13)).m3);
437-
main1_inner(local_invocation_index, tint_symbol_12, tint_symbol_14, tint_symbol_15);
481+
main1_inner(local_invocation_index, tint_symbol_12, &(tint_symbol_14), tint_symbol_15);
438482
return;
439483
}
440484
441-
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_17, threadgroup float3x3* const tint_symbol_18, threadgroup float3x4* const tint_symbol_19) {
485+
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_17, threadgroup tint_array<tint_packed_vec3_f32_array_element, 3>* const tint_symbol_18, threadgroup float3x4* const tint_symbol_19) {
442486
tint_zero_workgroup_memory_1(local_invocation_index_1, tint_symbol_17, tint_symbol_18, tint_symbol_19);
443487
float3x2 const a1 = *(tint_symbol_17);
444-
float3x3 const a2 = *(tint_symbol_18);
488+
float3x3 const a2 = tint_unpack_vec3_in_composite_1(*(tint_symbol_18));
445489
float3x4 const a3 = *(tint_symbol_19);
446490
}
447491
448492
kernel void main2(threadgroup tint_symbol_24* tint_symbol_21 [[threadgroup(0)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
449493
threadgroup float3x2* const tint_symbol_20 = &((*(tint_symbol_21)).m4);
450-
threadgroup float3x3* const tint_symbol_22 = &((*(tint_symbol_21)).m5);
494+
threadgroup tint_array<tint_packed_vec3_f32_array_element, 3> tint_symbol_22;
451495
threadgroup float3x4* const tint_symbol_23 = &((*(tint_symbol_21)).m6);
452-
main2_inner(local_invocation_index_1, tint_symbol_20, tint_symbol_22, tint_symbol_23);
496+
main2_inner(local_invocation_index_1, tint_symbol_20, &(tint_symbol_22), tint_symbol_23);
453497
return;
454498
}
455499
456-
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_25, threadgroup float4x3* const tint_symbol_26, threadgroup float4x4* const tint_symbol_27) {
500+
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_25, threadgroup tint_array<tint_packed_vec3_f32_array_element, 4>* const tint_symbol_26, threadgroup float4x4* const tint_symbol_27) {
457501
tint_zero_workgroup_memory_2(local_invocation_index_2, tint_symbol_25, tint_symbol_26, tint_symbol_27);
458502
float4x2 const a1 = *(tint_symbol_25);
459-
float4x3 const a2 = *(tint_symbol_26);
503+
float4x3 const a2 = tint_unpack_vec3_in_composite_2(*(tint_symbol_26));
460504
float4x4 const a3 = *(tint_symbol_27);
461505
}
462506
463507
kernel void main3(threadgroup tint_symbol_32* tint_symbol_29 [[threadgroup(0)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
464508
threadgroup float4x2* const tint_symbol_28 = &((*(tint_symbol_29)).m7);
465-
threadgroup float4x3* const tint_symbol_30 = &((*(tint_symbol_29)).m8);
509+
threadgroup tint_array<tint_packed_vec3_f32_array_element, 4> tint_symbol_30;
466510
threadgroup float4x4* const tint_symbol_31 = &((*(tint_symbol_29)).m9);
467-
main3_inner(local_invocation_index_2, tint_symbol_28, tint_symbol_30, tint_symbol_31);
511+
main3_inner(local_invocation_index_2, tint_symbol_28, &(tint_symbol_30), tint_symbol_31);
468512
return;
469513
}
470514
@@ -479,11 +523,11 @@ kernel void main4_no_usages() {
479523
ASSERT_TRUE(allocations.count("main2"));
480524
ASSERT_TRUE(allocations.count("main3"));
481525
ASSERT_EQ(allocations.at("main1").size(), 1u);
482-
EXPECT_EQ(allocations.at("main1")[0], 20u * sizeof(float));
526+
EXPECT_EQ(allocations.at("main1")[0], 12u * sizeof(float));
483527
ASSERT_EQ(allocations.at("main2").size(), 1u);
484-
EXPECT_EQ(allocations.at("main2")[0], 32u * sizeof(float));
528+
EXPECT_EQ(allocations.at("main2")[0], 20u * sizeof(float));
485529
ASSERT_EQ(allocations.at("main3").size(), 1u);
486-
EXPECT_EQ(allocations.at("main3")[0], 40u * sizeof(float));
530+
EXPECT_EQ(allocations.at("main3")[0], 24u * sizeof(float));
487531
EXPECT_EQ(allocations.at("main4_no_usages").size(), 0u);
488532
}
489533

chromium/third_party/dawn/src/tint/lang/msl/writer/ast_raise/packed_vec3.cc

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,14 @@ struct PackedVec3::State {
8383
/// A map from type to the name of a helper function used to unpack that type.
8484
Hashmap<const core::type::Type*, Symbol, 4> unpack_helpers;
8585

86+
/// @param addrspace the address space to test
/// @returns true if @p addrspace requires vec3 types to be packed
87+
bool AddressSpaceNeedsPacking(core::AddressSpace addrspace) {
88+
// Host-shareable address spaces need to be packed to match the memory layout on the host.
89+
// The workgroup address space needs to be packed so that the size of generated threadgroup
90+
// variables matches the size of the original WGSL declarations.
91+
return core::IsHostShareable(addrspace) || addrspace == core::AddressSpace::kWorkgroup;
92+
}
93+
8694
/// @param ty the type to test
8795
/// @returns true if `ty` is a vec3, false otherwise
8896
bool IsVec3(const core::type::Type* ty) {
@@ -373,7 +381,7 @@ struct PackedVec3::State {
373381
// if the transform is necessary.
374382
for (auto* decl : src.AST().GlobalVariables()) {
375383
auto* var = sem.Get<sem::GlobalVariable>(decl);
376-
if (var && core::IsHostShareable(var->AddressSpace()) &&
384+
if (var && AddressSpaceNeedsPacking(var->AddressSpace()) &&
377385
ContainsVec3(var->Type()->UnwrapRef())) {
378386
return true;
379387
}
@@ -410,7 +418,7 @@ struct PackedVec3::State {
410418
[&](const sem::TypeExpression* type) {
411419
// Rewrite pointers to types that contain vec3s.
412420
auto* ptr = type->Type()->As<core::type::Pointer>();
413-
if (ptr && core::IsHostShareable(ptr->AddressSpace())) {
421+
if (ptr && AddressSpaceNeedsPacking(ptr->AddressSpace())) {
414422
auto new_store_type = RewriteType(ptr->StoreType());
415423
if (new_store_type) {
416424
auto access = ptr->AddressSpace() == core::AddressSpace::kStorage
@@ -423,7 +431,7 @@ struct PackedVec3::State {
423431
}
424432
},
425433
[&](const sem::Variable* var) {
426-
if (!core::IsHostShareable(var->AddressSpace())) {
434+
if (!AddressSpaceNeedsPacking(var->AddressSpace())) {
427435
return;
428436
}
429437

@@ -439,7 +447,7 @@ struct PackedVec3::State {
439447
auto* lhs = sem.GetVal(assign->lhs);
440448
auto* rhs = sem.GetVal(assign->rhs);
441449
if (!ContainsVec3(rhs->Type()) ||
442-
!core::IsHostShareable(
450+
!AddressSpaceNeedsPacking(
443451
lhs->Type()->As<core::type::Reference>()->AddressSpace())) {
444452
// Skip assignments to address spaces that do not need packing, or
445453
// that do not contain vec3 types.
@@ -467,7 +475,7 @@ struct PackedVec3::State {
467475
[&](const sem::Load* load) {
468476
// Unpack loads of types that contain vec3s in address spaces that need packing.
469477
if (ContainsVec3(load->Type()) &&
470-
core::IsHostShareable(load->MemoryView()->AddressSpace())) {
478+
AddressSpaceNeedsPacking(load->MemoryView()->AddressSpace())) {
471479
to_unpack.Add(load);
472480
}
473481
},
@@ -477,7 +485,7 @@ struct PackedVec3::State {
477485
// struct.
478486
if (auto* ref = accessor->Type()->As<core::type::Reference>()) {
479487
if (IsVec3(ref->StoreType()) &&
480-
core::IsHostShareable(ref->AddressSpace())) {
488+
AddressSpaceNeedsPacking(ref->AddressSpace())) {
481489
ctx.Replace(node, b.MemberAccessor(ctx.Clone(accessor->Declaration()),
482490
kStructMemberName));
483491
}

0 commit comments

Comments
 (0)