Skip to content

Commit 60abda5

Browse files
authored
feat: select vulkan device with env variable (#629)
1 parent 23fce0b commit 60abda5

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

stable-diffusion.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,27 @@ class StableDiffusionGGML {
165165
#endif
166166
#ifdef SD_USE_VULKAN
167167
LOG_DEBUG("Using Vulkan backend");
168-
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
168+
size_t device = 0;
169+
const int device_count = ggml_backend_vk_get_device_count();
170+
if (device_count) {
171+
const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE");
172+
if (SD_VK_DEVICE != nullptr) {
173+
std::string sd_vk_device_str = SD_VK_DEVICE;
174+
try {
175+
device = std::stoull(sd_vk_device_str);
176+
} catch (const std::invalid_argument&) {
177+
LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer (%s). Falling back to device 0.", SD_VK_DEVICE);
178+
device = 0;
179+
} catch (const std::out_of_range&) {
180+
LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to device 0.", SD_VK_DEVICE);
181+
device = 0;
182+
}
183+
if (device >= device_count) {
184+
LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device);
185+
device = 0;
186+
}
187+
}
188+
LOG_INFO("Vulkan: Using device %llu", device);
169189
backend = ggml_backend_vk_init(device);
170190
}
171191
if (!backend) {

0 commit comments

Comments
 (0)