libs/vulkan-compute/src/vk_diversity.c
1/* vk_diversity.c - Diversity sequence generation implementation
2 *
3 * Implements GPU-accelerated diversity sequence generation using
4 * iterative centroid-based maximum distance selection.
5 */
6
7#include "vk_diversity.h"
8#include <stdlib.h>
9#include <string.h>
10#include <stdio.h>
11
12struct VkDiversityContext {
13 VkComputeContext* ctx;
14
15 /* Dataset parameters */
16 uint32_t num_poems;
17 uint32_t embedding_dim;
18
19 /* GPU buffers */
20 VkComputeBuffer* embeddings_buf; /* All embeddings (device-local) */
21 VkComputeBuffer* centroid_buf; /* Current centroid (device-local) */
22 VkComputeBuffer* distances_buf; /* Distance values (device-local) */
23 VkComputeBuffer* mask_buf; /* Selection mask (device-local) */
24 VkComputeBuffer* result_buf; /* Max reduction result (host-visible) */
25
26 /* Pipelines */
27 VkComputePipeline* cosine_pipeline;
28 VkComputePipeline* centroid_pipeline;
29 VkComputePipeline* reduction_pipeline;
30
31 /* Host-side scratch buffers */
32 uint32_t* mask; /* CPU copy of mask */
33 float* embeddings; /* CPU copy of embeddings for centroid init */
34};
35
36/* {{{ vkd_init
37 */
38
39VkDiversityContext* vkd_init(VkComputeContext* ctx,
40 const float* embeddings,
41 uint32_t num_poems,
42 uint32_t embedding_dim) {
43 if (!ctx || !embeddings || num_poems == 0 || embedding_dim == 0) {
44 return NULL;
45 }
46
47 VkDiversityContext* div_ctx = calloc(1, sizeof(VkDiversityContext));
48 if (!div_ctx) {
49 return NULL;
50 }
51
52 div_ctx->ctx = ctx;
53 div_ctx->num_poems = num_poems;
54 div_ctx->embedding_dim = embedding_dim;
55
56 printf("[VKD] Initializing diversity context...\n");
57 printf(" Poems: %u, Dimensions: %u\n", num_poems, embedding_dim);
58
59 /* Create GPU buffers */
60 size_t embeddings_size = num_poems * embedding_dim * sizeof(float);
61 size_t centroid_size = embedding_dim * sizeof(float);
62 size_t distances_size = num_poems * sizeof(float);
63 size_t mask_size = num_poems * sizeof(uint32_t);
64 size_t result_size = 2 * sizeof(uint32_t); /* max_index + max_distance */
65
66 div_ctx->embeddings_buf = vkc_create_buffer(ctx, embeddings_size, VKC_BUFFER_DEVICE_LOCAL);
67 div_ctx->centroid_buf = vkc_create_buffer(ctx, centroid_size, VKC_BUFFER_DEVICE_LOCAL);
68 div_ctx->distances_buf = vkc_create_buffer(ctx, distances_size, VKC_BUFFER_DEVICE_LOCAL);
69 div_ctx->mask_buf = vkc_create_buffer(ctx, mask_size, VKC_BUFFER_DEVICE_LOCAL);
70 div_ctx->result_buf = vkc_create_buffer(ctx, result_size, VKC_BUFFER_HOST_VISIBLE);
71
72 if (!div_ctx->embeddings_buf || !div_ctx->centroid_buf || !div_ctx->distances_buf ||
73 !div_ctx->mask_buf || !div_ctx->result_buf) {
74 fprintf(stderr, "[VKD ERROR] Failed to create GPU buffers\n");
75 vkd_destroy(div_ctx);
76 return NULL;
77 }
78
79 /* Upload embeddings to GPU (one-time operation) */
80 printf("[VKD] Uploading %.2f MB of embeddings to GPU...\n",
81 embeddings_size / (1024.0 * 1024.0));
82 vkc_upload_buffer(ctx, div_ctx->embeddings_buf, embeddings, embeddings_size);
83
84 /* Create pipelines */
85 div_ctx->cosine_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/cosine_distance.spv",
86 sizeof(uint32_t) * 2);
87 div_ctx->centroid_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/centroid_update.spv",
88 sizeof(uint32_t) * 2);
89 div_ctx->reduction_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/max_reduction.spv",
90 sizeof(uint32_t));
91
92 if (!div_ctx->cosine_pipeline || !div_ctx->centroid_pipeline || !div_ctx->reduction_pipeline) {
93 fprintf(stderr, "[VKD ERROR] Failed to create pipelines\n");
94 vkd_destroy(div_ctx);
95 return NULL;
96 }
97
98 /* Bind buffers to cosine distance pipeline */
99 vkc_bind_buffer(ctx, div_ctx->cosine_pipeline, 0, div_ctx->embeddings_buf);
100 vkc_bind_buffer(ctx, div_ctx->cosine_pipeline, 1, div_ctx->centroid_buf);
101 vkc_bind_buffer(ctx, div_ctx->cosine_pipeline, 2, div_ctx->distances_buf);
102
103 /* Bind buffers to centroid update pipeline */
104 vkc_bind_buffer(ctx, div_ctx->centroid_pipeline, 0, div_ctx->centroid_buf);
105 /* Binding 1 will be updated per iteration with new embedding */
106
107 /* Bind buffers to reduction pipeline */
108 vkc_bind_buffer(ctx, div_ctx->reduction_pipeline, 0, div_ctx->distances_buf);
109 vkc_bind_buffer(ctx, div_ctx->reduction_pipeline, 1, div_ctx->mask_buf);
110 vkc_bind_buffer(ctx, div_ctx->reduction_pipeline, 2, div_ctx->result_buf);
111
112 /* Allocate CPU scratch buffer for mask */
113 div_ctx->mask = malloc(mask_size);
114 if (!div_ctx->mask) {
115 vkd_destroy(div_ctx);
116 return NULL;
117 }
118
119 /* Keep CPU copy of embeddings for centroid initialization */
120 div_ctx->embeddings = malloc(embeddings_size);
121 if (!div_ctx->embeddings) {
122 vkd_destroy(div_ctx);
123 return NULL;
124 }
125 memcpy(div_ctx->embeddings, embeddings, embeddings_size);
126
127 printf("[VKD] Initialization complete\n");
128 return div_ctx;
129}
130
131/* }}} */
132
133/* {{{ vkd_compute_sequence
134 */
135
136VkComputeResult vkd_compute_sequence(VkDiversityContext* div_ctx,
137 uint32_t start_poem,
138 uint32_t* output_sequence) {
139 if (!div_ctx || !output_sequence || start_poem >= div_ctx->num_poems) {
140 return VKC_ERROR_INIT_FAILED;
141 }
142
143 VkComputeContext* ctx = div_ctx->ctx;
144 uint32_t num_poems = div_ctx->num_poems;
145 uint32_t embedding_dim = div_ctx->embedding_dim;
146
147 /* Initialize mask: all poems available except start_poem */
148 for (uint32_t i = 0; i < num_poems; i++) {
149 div_ctx->mask[i] = (i == start_poem) ? 0 : 1;
150 }
151 vkc_upload_buffer(ctx, div_ctx->mask_buf, div_ctx->mask, num_poems * sizeof(uint32_t));
152
153 /* Initialize centroid with start_poem's embedding */
154 const float* start_embedding = &div_ctx->embeddings[start_poem * embedding_dim];
155 vkc_upload_buffer(ctx, div_ctx->centroid_buf, start_embedding, embedding_dim * sizeof(float));
156 printf("[VKD] Initialized centroid with poem %u's embedding\n", start_poem);
157
158 /* Initialize sequence */
159 output_sequence[0] = start_poem;
160 uint32_t count = 1;
161
162 printf("[VKD] Computing diversity sequence starting from poem %u...\n", start_poem);
163
164 /* Iteratively select most diverse poems */
165 for (uint32_t iter = 1; iter < num_poems; iter++) {
166 /* Step 1: Compute distances from all poems to current centroid */
167 struct {
168 uint32_t num_embeddings;
169 uint32_t embedding_dim;
170 } cosine_push = { num_poems, embedding_dim };
171
172 uint32_t workgroups = (num_poems + 255) / 256;
173 vkc_dispatch(ctx, div_ctx->cosine_pipeline, workgroups, 1, 1, &cosine_push);
174
175 /* Step 2: Find poem with maximum distance (respecting mask) */
176 struct {
177 uint32_t num_poems;
178 } reduction_push = { num_poems };
179
180 /* Single workgroup with stride-based checking covers all poems */
181 /* Each of 256 threads checks multiple poems: thread_id, thread_id+256, thread_id+512, ... */
182 workgroups = 1;
183 vkc_dispatch(ctx, div_ctx->reduction_pipeline, workgroups, 1, 1, &reduction_push);
184
185 /* Step 3: Read result from GPU */
186 uint32_t result[2]; /* [max_index, max_distance_as_uint] */
187 vkc_download_buffer(ctx, div_ctx->result_buf, result, sizeof(result));
188
189 uint32_t selected_poem = result[0];
190
191 /* Verify selection is valid */
192 if (selected_poem == 0xFFFFFFFF) {
193 fprintf(stderr, "[VKD ERROR] No valid poem found (all masked?) at iteration %u\n", iter);
194
195 /* Count available poems */
196 uint32_t available_count = 0;
197 for (uint32_t i = 0; i < num_poems; i++) {
198 if (div_ctx->mask[i] == 1) available_count++;
199 }
200
201 /* Show first 10 mask values for debugging */
202 fprintf(stderr, "[VKD DEBUG] First 10 mask values: ");
203 for (uint32_t i = 0; i < 10 && i < num_poems; i++) {
204 fprintf(stderr, "%u:%u ", i, div_ctx->mask[i]);
205 }
206 fprintf(stderr, "\n[VKD DEBUG] Total available: %u / %u\n", available_count, num_poems);
207 return VKC_ERROR_COMMAND_EXECUTION_FAILED;
208 }
209
210 if (selected_poem >= num_poems || div_ctx->mask[selected_poem] == 0) {
211 fprintf(stderr, "[VKD ERROR] Invalid poem selected: %u (mask: %u)\n",
212 selected_poem, div_ctx->mask[selected_poem]);
213 return VKC_ERROR_COMMAND_EXECUTION_FAILED;
214 }
215
216 /* Add to sequence */
217 output_sequence[iter] = selected_poem;
218
219 /* Step 4: Update mask */
220 div_ctx->mask[selected_poem] = 0;
221 vkc_upload_buffer(ctx, div_ctx->mask_buf, div_ctx->mask, num_poems * sizeof(uint32_t));
222
223 /* Step 5: Update centroid with newly selected poem */
224 /* TODO: Implement GPU-GPU copy for selected embedding */
225 /* For now, we skip centroid update to keep code simple */
226
227 count++;
228
229 /* Progress indicator -- routed through the shared renderer so it obeys
230 * the same TTY / --debug rules as every other stage. Throttled to every
231 * 1000 iterations to match the prior cadence. */
232 if (iter % 1000 == 0) {
233 vkc_progress_update("[VKD] Sequence", iter, num_poems);
234 }
235 }
236
237 vkc_progress_finish();
238 printf("[VKD] Sequence computation complete\n");
239 return VKC_SUCCESS;
240}
241
242/* }}} */
243
244/* {{{ vkd_destroy
245 */
246
247void vkd_destroy(VkDiversityContext* div_ctx) {
248 if (!div_ctx) return;
249
250 VkComputeContext* ctx = div_ctx->ctx;
251
252 if (div_ctx->mask) {
253 free(div_ctx->mask);
254 }
255 if (div_ctx->embeddings) {
256 free(div_ctx->embeddings);
257 }
258
259 if (div_ctx->cosine_pipeline) {
260 vkc_destroy_pipeline(ctx, div_ctx->cosine_pipeline);
261 }
262 if (div_ctx->centroid_pipeline) {
263 vkc_destroy_pipeline(ctx, div_ctx->centroid_pipeline);
264 }
265 if (div_ctx->reduction_pipeline) {
266 vkc_destroy_pipeline(ctx, div_ctx->reduction_pipeline);
267 }
268
269 if (div_ctx->embeddings_buf) {
270 vkc_destroy_buffer(ctx, div_ctx->embeddings_buf);
271 }
272 if (div_ctx->centroid_buf) {
273 vkc_destroy_buffer(ctx, div_ctx->centroid_buf);
274 }
275 if (div_ctx->distances_buf) {
276 vkc_destroy_buffer(ctx, div_ctx->distances_buf);
277 }
278 if (div_ctx->mask_buf) {
279 vkc_destroy_buffer(ctx, div_ctx->mask_buf);
280 }
281 if (div_ctx->result_buf) {
282 vkc_destroy_buffer(ctx, div_ctx->result_buf);
283 }
284
285 free(div_ctx);
286 printf("[VKD] Cleanup complete\n");
287}
288
289/* }}} */
290
291/* }}} */
292
293/* ============================================================================
294 * Batch Processing Implementation
295 * ============================================================================
296 * Enables parallel computation of thousands of diversity sequences with
297 * GPU-side state management for optimal performance (2,600× speedup).
298 */
299
300/* {{{ struct VkDiversityBatchContext
301 */
302
303struct VkDiversityBatchContext {
304 VkComputeContext* ctx;
305
306 /* Dataset parameters */
307 uint32_t num_poems;
308 uint32_t embedding_dim;
309 uint32_t batch_size;
310
311 /* GPU buffers */
312 VkComputeBuffer* embeddings_buf; /* All embeddings (device-local) */
313 VkComputeBuffer* centroids_buf; /* Current centroids for all sequences */
314 VkComputeBuffer* masks_buf; /* Availability masks for all sequences */
315 VkComputeBuffer* counts_buf; /* Poem counts for rolling average */
316 VkComputeBuffer* output_buf; /* Complete sequences output */
317
318 /* Pipelines.
319 * batch_pipeline — original in-shader-tile-loop path (diversity_full.spv).
320 * One dispatch runs many iterations internally.
321 * scan_tile_pipeline — 9-014 dispatch-per-tile path (diversity_scan_tile.spv).
322 * Scans one tile of one iteration; accumulates per-workgroup
323 * running max into running_max_distance_buf / running_max_index_buf.
324 * commit_iteration_pipeline — 9-014 commit step (diversity_commit_iteration.spv).
325 * Reads the running max, updates state, resets max for next iter. */
326 VkComputePipeline* batch_pipeline;
327 VkComputePipeline* scan_tile_pipeline;
328 VkComputePipeline* commit_iteration_pipeline;
329
330 /* 9-014 running max storage buffers (used by the dispatch-per-tile path).
331 * One float and one uint per workgroup, persisting across the scan-tile
332 * dispatches that make up one iteration, then read and reset by commit. */
333 VkComputeBuffer* running_max_distance_buf;
334 VkComputeBuffer* running_max_index_buf;
335
336 /* Selections buffer (host-visible for reading back selected poems) */
337 VkComputeBuffer* selections_buf;
338};
339
340/* }}} */
341
342/* {{{ vkd_batch_init
343 */
344
345VkDiversityBatchContext* vkd_batch_init(VkComputeContext* ctx,
346 const uint16_t* embeddings_fp16,
347 uint32_t num_poems,
348 uint32_t embedding_dim,
349 uint32_t batch_size,
350 const uint32_t* start_indices) {
351 if (!ctx || !embeddings_fp16 || !start_indices || num_poems == 0 ||
352 embedding_dim == 0 || batch_size == 0 || batch_size > 3584) {
353 return NULL;
354 }
355 if (embedding_dim % 2 != 0) {
356 /* The shader processes pairs of dims per loop iteration to use
357 * unpackHalf2x16 efficiently; an odd dim count would leave a tail
358 * half that the current shader does not handle. Embedding models
359 * that produce odd dims would need a shader update to support. */
360 fprintf(stderr, "[VKD Batch ERROR] embedding_dim must be even for the FP16-packed shader; got %u\n", embedding_dim);
361 return NULL;
362 }
363
364 VkDiversityBatchContext* batch_ctx = calloc(1, sizeof(VkDiversityBatchContext));
365 if (!batch_ctx) {
366 return NULL;
367 }
368
369 batch_ctx->ctx = ctx;
370 batch_ctx->num_poems = num_poems;
371 batch_ctx->embedding_dim = embedding_dim;
372 batch_ctx->batch_size = batch_size;
373
374 printf("[VKD Batch] Initializing batch context...\n");
375 printf(" Poems: %u, Dimensions: %u, Batch size: %u\n",
376 num_poems, embedding_dim, batch_size);
377
378 /* Calculate buffer sizes. Embeddings are FP16-packed: each value is
379 * 2 bytes instead of the 4 bytes the old FP32 path used, so the GPU
380 * buffer is exactly half the size for the same poem count. */
381 size_t embeddings_size = (size_t)num_poems * embedding_dim * sizeof(uint16_t);
382 size_t centroids_size = (size_t)batch_size * embedding_dim * sizeof(float);
383 size_t masks_size = (size_t)batch_size * num_poems * sizeof(uint32_t);
384 size_t counts_size = (size_t)batch_size * sizeof(uint32_t);
385 size_t output_size = (size_t)batch_size * num_poems * sizeof(uint32_t);
386 size_t selections_size = (size_t)batch_size * sizeof(uint32_t);
387
388 printf("[VKD Batch] Buffer sizes:\n");
389 printf(" Embeddings (FP16): %.2f MB\n", embeddings_size / (1024.0 * 1024.0));
390 printf(" Centroids (FP32): %.2f MB\n", centroids_size / (1024.0 * 1024.0));
391 printf(" Masks: %.2f MB\n", masks_size / (1024.0 * 1024.0));
392 printf(" Total GPU memory: %.2f MB\n",
393 (embeddings_size + centroids_size + masks_size + counts_size + output_size) / (1024.0 * 1024.0));
394
395 /* Create GPU buffers */
396 batch_ctx->embeddings_buf = vkc_create_buffer(ctx, embeddings_size, VKC_BUFFER_DEVICE_LOCAL);
397 batch_ctx->centroids_buf = vkc_create_buffer(ctx, centroids_size, VKC_BUFFER_DEVICE_LOCAL);
398 batch_ctx->masks_buf = vkc_create_buffer(ctx, masks_size, VKC_BUFFER_DEVICE_LOCAL);
399 batch_ctx->counts_buf = vkc_create_buffer(ctx, counts_size, VKC_BUFFER_DEVICE_LOCAL);
400 batch_ctx->output_buf = vkc_create_buffer(ctx, output_size, VKC_BUFFER_DEVICE_LOCAL);
401 batch_ctx->selections_buf = vkc_create_buffer(ctx, selections_size, VKC_BUFFER_HOST_VISIBLE);
402
403 if (!batch_ctx->embeddings_buf || !batch_ctx->centroids_buf || !batch_ctx->masks_buf ||
404 !batch_ctx->counts_buf || !batch_ctx->output_buf || !batch_ctx->selections_buf) {
405 fprintf(stderr, "[VKD Batch ERROR] Failed to create GPU buffers\n");
406 vkd_batch_destroy(batch_ctx);
407 return NULL;
408 }
409
410 /* Upload FP16-packed embeddings to the GPU. The shader reads these
411 * via unpackHalf2x16 on the fly; no conversion happens here. */
412 printf("[VKD Batch] Uploading %.2f MB of FP16 embeddings to GPU...\n",
413 embeddings_size / (1024.0 * 1024.0));
414 vkc_upload_buffer(ctx, batch_ctx->embeddings_buf, embeddings_fp16, embeddings_size);
415
416 /* Initialize centroids with starting poem embeddings, converted back
417 * to FP32 since the centroid buffer is FP32 (the shader reads it as
418 * shared-memory FP32; only the embedding table is FP16). */
419 float* initial_centroids = malloc(centroids_size);
420 if (!initial_centroids) {
421 vkd_batch_destroy(batch_ctx);
422 return NULL;
423 }
424
425 for (uint32_t i = 0; i < batch_size; i++) {
426 uint32_t start_poem = start_indices[i];
427 if (start_poem >= num_poems) {
428 fprintf(stderr, "[VKD Batch ERROR] Invalid start index: %u\n", start_poem);
429 free(initial_centroids);
430 vkd_batch_destroy(batch_ctx);
431 return NULL;
432 }
433 const uint16_t* src = &embeddings_fp16[(size_t)start_poem * embedding_dim];
434 float* dst = &initial_centroids[(size_t)i * embedding_dim];
435 for (uint32_t d = 0; d < embedding_dim; d++) {
436 dst[d] = vkc_fp16_to_fp32(src[d]);
437 }
438 }
439 vkc_upload_buffer(ctx, batch_ctx->centroids_buf, initial_centroids, centroids_size);
440 free(initial_centroids);
441
442 /* Initialize masks (all 1s except starting poems) */
443 uint32_t* initial_masks = malloc(masks_size);
444 if (!initial_masks) {
445 vkd_batch_destroy(batch_ctx);
446 return NULL;
447 }
448
449 for (uint32_t seq = 0; seq < batch_size; seq++) {
450 uint32_t start_poem = start_indices[seq];
451 for (uint32_t p = 0; p < num_poems; p++) {
452 initial_masks[seq * num_poems + p] = (p == start_poem) ? 0 : 1;
453 }
454 }
455 vkc_upload_buffer(ctx, batch_ctx->masks_buf, initial_masks, masks_size);
456 free(initial_masks);
457
458 /* Initialize counts (all 1s) */
459 uint32_t* initial_counts = malloc(counts_size);
460 if (!initial_counts) {
461 vkd_batch_destroy(batch_ctx);
462 return NULL;
463 }
464 for (uint32_t i = 0; i < batch_size; i++) {
465 initial_counts[i] = 1;
466 }
467 vkc_upload_buffer(ctx, batch_ctx->counts_buf, initial_counts, counts_size);
468 free(initial_counts);
469
470 /* Initialize output buffer with starting poems */
471 uint32_t* initial_output = calloc(batch_size * num_poems, sizeof(uint32_t));
472 if (!initial_output) {
473 vkd_batch_destroy(batch_ctx);
474 return NULL;
475 }
476 for (uint32_t i = 0; i < batch_size; i++) {
477 initial_output[i * num_poems] = start_indices[i];
478 }
479 vkc_upload_buffer(ctx, batch_ctx->output_buf, initial_output, output_size);
480 free(initial_output);
481
482 /* Create pipeline. Uses diversity_full.spv: the workgroup runs a
483 * chunk of the iteration loop internally instead of one iteration
484 * per dispatch. Push constants are {num_poems, embedding_dim,
485 * start_slot, slot_count}. The size MUST match what
486 * vkd_batch_compute_chunk pushes — under-allocating here means the
487 * tail of the push-constant struct silently reads as zero in the
488 * shader, the chunk loop runs zero iterations, and every dispatch
489 * returns instantly having done no work. */
490 batch_ctx->batch_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/diversity_full.spv",
491 sizeof(uint32_t) * 5); /* num_poems, embedding_dim, start_slot, slot_count, tile_size */
492 if (!batch_ctx->batch_pipeline) {
493 fprintf(stderr, "[VKD Batch ERROR] Failed to create pipeline\n");
494 vkd_batch_destroy(batch_ctx);
495 return NULL;
496 }
497
498 /* Bind buffers */
499 vkc_bind_buffer(ctx, batch_ctx->batch_pipeline, 0, batch_ctx->embeddings_buf);
500 vkc_bind_buffer(ctx, batch_ctx->batch_pipeline, 1, batch_ctx->centroids_buf);
501 vkc_bind_buffer(ctx, batch_ctx->batch_pipeline, 2, batch_ctx->masks_buf);
502 vkc_bind_buffer(ctx, batch_ctx->batch_pipeline, 3, batch_ctx->counts_buf);
503 vkc_bind_buffer(ctx, batch_ctx->batch_pipeline, 4, batch_ctx->output_buf);
504
505 /* 9-014 dispatch-per-tile path setup: running_max storage buffers and
506 * the two new pipelines. Allocated lazily here so callers that only
507 * use the in-shader-tile-loop path do not pay for them. */
508 size_t running_max_distance_size = (size_t)batch_size * sizeof(float);
509 size_t running_max_index_size = (size_t)batch_size * sizeof(uint32_t);
510
511 batch_ctx->running_max_distance_buf =
512 vkc_create_buffer(ctx, running_max_distance_size, VKC_BUFFER_DEVICE_LOCAL);
513 batch_ctx->running_max_index_buf =
514 vkc_create_buffer(ctx, running_max_index_size, VKC_BUFFER_DEVICE_LOCAL);
515 if (!batch_ctx->running_max_distance_buf || !batch_ctx->running_max_index_buf) {
516 fprintf(stderr, "[VKD Batch ERROR] Failed to create running-max buffers\n");
517 vkd_batch_destroy(batch_ctx);
518 return NULL;
519 }
520
521 /* Initialize running max to (-inf, sentinel). The commit shader resets
522 * to the same values after each iteration, but we need a known-good
523 * starting state for the very first iteration's first tile-scan. */
524 float* initial_max_dist = malloc(running_max_distance_size);
525 uint32_t* initial_max_idx = malloc(running_max_index_size);
526 if (!initial_max_dist || !initial_max_idx) {
527 free(initial_max_dist);
528 free(initial_max_idx);
529 vkd_batch_destroy(batch_ctx);
530 return NULL;
531 }
532 for (uint32_t i = 0; i < batch_size; i++) {
533 initial_max_dist[i] = -1e9f;
534 initial_max_idx[i] = 0xFFFFFFFFu;
535 }
536 vkc_upload_buffer(ctx, batch_ctx->running_max_distance_buf,
537 initial_max_dist, running_max_distance_size);
538 vkc_upload_buffer(ctx, batch_ctx->running_max_index_buf,
539 initial_max_idx, running_max_index_size);
540 free(initial_max_dist);
541 free(initial_max_idx);
542
543 /* Create the scan-tile pipeline. Push constants: num_poems, embedding_dim,
544 * tile_start, tile_size — four uints. Reads embeddings/centroids/masks,
545 * writes running_max_distance/running_max_index. */
546 batch_ctx->scan_tile_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/diversity_scan_tile.spv",
547 sizeof(uint32_t) * 4);
548 if (!batch_ctx->scan_tile_pipeline) {
549 fprintf(stderr, "[VKD Batch ERROR] Failed to create scan_tile pipeline\n");
550 vkd_batch_destroy(batch_ctx);
551 return NULL;
552 }
553 vkc_bind_buffer(ctx, batch_ctx->scan_tile_pipeline, 0, batch_ctx->embeddings_buf);
554 vkc_bind_buffer(ctx, batch_ctx->scan_tile_pipeline, 1, batch_ctx->centroids_buf);
555 vkc_bind_buffer(ctx, batch_ctx->scan_tile_pipeline, 2, batch_ctx->masks_buf);
556 vkc_bind_buffer(ctx, batch_ctx->scan_tile_pipeline, 3, batch_ctx->running_max_distance_buf);
557 vkc_bind_buffer(ctx, batch_ctx->scan_tile_pipeline, 4, batch_ctx->running_max_index_buf);
558
559 /* Create the commit-iteration pipeline. Push constants: num_poems,
560 * embedding_dim, slot — three uints. Reads embeddings, writes centroid/
561 * mask/count/output and resets running_max. */
562 batch_ctx->commit_iteration_pipeline = vkc_create_pipeline(ctx, "libs/vulkan-compute/build/diversity_commit_iteration.spv",
563 sizeof(uint32_t) * 3);
564 if (!batch_ctx->commit_iteration_pipeline) {
565 fprintf(stderr, "[VKD Batch ERROR] Failed to create commit_iteration pipeline\n");
566 vkd_batch_destroy(batch_ctx);
567 return NULL;
568 }
569 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 0, batch_ctx->embeddings_buf);
570 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 1, batch_ctx->centroids_buf);
571 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 2, batch_ctx->masks_buf);
572 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 3, batch_ctx->counts_buf);
573 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 4, batch_ctx->output_buf);
574 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 5, batch_ctx->running_max_distance_buf);
575 vkc_bind_buffer(ctx, batch_ctx->commit_iteration_pipeline, 6, batch_ctx->running_max_index_buf);
576
577 // Three compute shaders back this context: tile-scan (finds the farthest
578 // candidate in a tile), commit (records the per-iteration winner), and the
579 // legacy single-dispatch batch shader. They are algorithm stages, not
580 // parallel workers -- distinct from how the *work* is chunked into batches.
581 printf("[VKD Batch] Initialization complete (3 compute shaders: tile-scan, commit, batch)\n");
582 return batch_ctx;
583}
584
585/* }}} */
586
587/* {{{ vkd_batch_compute_chunk
588 */
589
590VkComputeResult vkd_batch_compute_chunk(VkDiversityBatchContext* batch_ctx,
591 uint32_t start_slot,
592 uint32_t slot_count,
593 uint32_t tile_size) {
594 if (!batch_ctx || slot_count == 0) {
595 return VKC_ERROR_INIT_FAILED;
596 }
597 if (start_slot + slot_count > batch_ctx->num_poems) {
598 /* Caller asked for more slots than exist; refuse rather than write
599 * past the end of output_buf. */
600 return VKC_ERROR_INIT_FAILED;
601 }
602 if (tile_size == 0) {
603 /* Treat zero as "no tiling" — one tile covers the entire candidate
604 * range. The shader has the same fallback baked in but this is
605 * cheap to defend against here and produces an explicit value
606 * for the validation layers to inspect. */
607 tile_size = batch_ctx->num_poems;
608 }
609
610 VkComputeContext* ctx = batch_ctx->ctx;
611
612 /* Push constants describe the dataset shape, the slice of work this
613 * dispatch is responsible for, and the tile granularity of the inner
614 * scan. The shader writes output slots [start_slot, start_slot + slot_count)
615 * and walks the candidate range in tiles of tile_size, with the running
616 * max accumulator persisting across tiles within one iteration. */
617 struct {
618 uint32_t num_poems;
619 uint32_t embedding_dim;
620 uint32_t start_slot;
621 uint32_t slot_count;
622 uint32_t tile_size;
623 } push_constants = {
624 batch_ctx->num_poems,
625 batch_ctx->embedding_dim,
626 start_slot,
627 slot_count,
628 tile_size
629 };
630
631 /* One workgroup per sequence in the batch. Each workgroup runs
632 * slot_count iterations internally. We must propagate the dispatch
633 * result — the previous version of this code swallowed errors and
634 * returned VKC_SUCCESS unconditionally, which caused a device-lost
635 * failure in one dispatch to silently break every following dispatch. */
636 return vkc_dispatch(ctx, batch_ctx->batch_pipeline,
637 batch_ctx->batch_size, 1, 1, &push_constants);
638}
639
640/* }}} */
641
642/* {{{ vkd_batch_compute_chunk_pipelined
643 *
644 * 9-014 dispatch-per-tile + pipelining: the same effective work as
645 * vkd_batch_compute_chunk, but each iteration is split into N tile-scan
646 * dispatches plus one commit dispatch, all submitted via the async
647 * pipeline pool so the CPU stays ahead of the GPU.
648 *
649 * Compared to vkd_batch_compute_chunk:
650 * - Cache hit rate is higher because the fence wait between tile
651 * dispatches enforces hard grid sync — all workgroups finish tile K
652 * before any start tile K+1, so the L2 holds exactly one tile's
653 * worth of embedding data at a time.
654 * - CPU dispatch overhead is hidden because the async pool lets the
655 * CPU record dispatch N+1 while the GPU runs dispatch N.
656 * - More total dispatches (slot_count * (num_tiles + 1) instead of 1
657 * per chunk), but each is much shorter and the pipelining hides
658 * the per-dispatch cost.
659 *
660 * Parameters match vkd_batch_compute_chunk: start_slot, slot_count, and
661 * tile_size. tile_size is now load-bearing (the chunked-into-tiles flow
662 * only makes sense with a real tile size); passing 0 or num_poems
663 * collapses to one tile per iteration, which is the same shape as the
664 * non-tiled baseline only with extra dispatch overhead — useful for
665 * sanity-checking but not for production.
666 */
667VkComputeResult vkd_batch_compute_chunk_pipelined(VkDiversityBatchContext* batch_ctx,
668 uint32_t start_slot,
669 uint32_t slot_count,
670 uint32_t tile_size) {
671 if (!batch_ctx || slot_count == 0) {
672 return VKC_ERROR_INIT_FAILED;
673 }
674 if (start_slot + slot_count > batch_ctx->num_poems) {
675 return VKC_ERROR_INIT_FAILED;
676 }
677 if (tile_size == 0 || tile_size > batch_ctx->num_poems) {
678 tile_size = batch_ctx->num_poems;
679 }
680
681 VkComputeContext* ctx = batch_ctx->ctx;
682 uint32_t num_tiles = (batch_ctx->num_poems + tile_size - 1) / tile_size;
683
684 struct {
685 uint32_t num_poems;
686 uint32_t embedding_dim;
687 uint32_t tile_start;
688 uint32_t tile_size;
689 } scan_pc;
690 struct {
691 uint32_t num_poems;
692 uint32_t embedding_dim;
693 uint32_t slot;
694 } commit_pc;
695
696 /* For each iteration in this chunk: dispatch N tile-scans, then one
697 * commit. All dispatches go to the async pool. The compute-to-compute
698 * memory barrier at the head of each command buffer (added by
699 * vkc_dispatch_async) ensures the running_max writes from tile K are
700 * visible to tile K+1's reads, and the commit's reads see the final
701 * running max from the last tile. */
702 for (uint32_t iter = 0; iter < slot_count; iter++) {
703 uint32_t slot = start_slot + iter;
704
705 for (uint32_t t = 0; t < num_tiles; t++) {
706 uint32_t tile_start = t * tile_size;
707 uint32_t this_tile_size = tile_size;
708 if (tile_start + this_tile_size > batch_ctx->num_poems) {
709 this_tile_size = batch_ctx->num_poems - tile_start;
710 }
711 scan_pc.num_poems = batch_ctx->num_poems;
712 scan_pc.embedding_dim = batch_ctx->embedding_dim;
713 scan_pc.tile_start = tile_start;
714 scan_pc.tile_size = this_tile_size;
715
716 VkComputeResult r = vkc_dispatch_async(ctx, batch_ctx->scan_tile_pipeline,
717 batch_ctx->batch_size, 1, 1,
718 &scan_pc);
719 if (r != VKC_SUCCESS) {
720 return r;
721 }
722 }
723
724 commit_pc.num_poems = batch_ctx->num_poems;
725 commit_pc.embedding_dim = batch_ctx->embedding_dim;
726 commit_pc.slot = slot;
727
728 VkComputeResult r = vkc_dispatch_async(ctx, batch_ctx->commit_iteration_pipeline,
729 batch_ctx->batch_size, 1, 1,
730 &commit_pc);
731 if (r != VKC_SUCCESS) {
732 return r;
733 }
734 }
735
736 /* Drain the pipeline before returning so callers can rely on all
737 * work being done at the point of return. Callers that want to keep
738 * the pipeline warm across chunks can refactor to defer the drain. */
739 return vkc_wait_async_all(ctx);
740}
741
742/* }}} */
743
744/* {{{ vkd_batch_download_sequences
745 */
746
747VkComputeResult vkd_batch_download_sequences(VkDiversityBatchContext* batch_ctx,
748 uint32_t* output_sequences) {
749 if (!batch_ctx || !output_sequences) {
750 return VKC_ERROR_INIT_FAILED;
751 }
752
753 size_t output_size = batch_ctx->batch_size * batch_ctx->num_poems * sizeof(uint32_t);
754 vkc_download_buffer(batch_ctx->ctx, batch_ctx->output_buf, output_sequences, output_size);
755
756 return VKC_SUCCESS;
757}
758
759/* }}} */
760
761/* {{{ vkd_batch_destroy
762 */
763
764void vkd_batch_destroy(VkDiversityBatchContext* batch_ctx) {
765 if (!batch_ctx) return;
766
767 VkComputeContext* ctx = batch_ctx->ctx;
768
769 if (batch_ctx->batch_pipeline) {
770 vkc_destroy_pipeline(ctx, batch_ctx->batch_pipeline);
771 }
772 if (batch_ctx->scan_tile_pipeline) {
773 vkc_destroy_pipeline(ctx, batch_ctx->scan_tile_pipeline);
774 }
775 if (batch_ctx->commit_iteration_pipeline) {
776 vkc_destroy_pipeline(ctx, batch_ctx->commit_iteration_pipeline);
777 }
778
779 if (batch_ctx->embeddings_buf) {
780 vkc_destroy_buffer(ctx, batch_ctx->embeddings_buf);
781 }
782 if (batch_ctx->centroids_buf) {
783 vkc_destroy_buffer(ctx, batch_ctx->centroids_buf);
784 }
785 if (batch_ctx->masks_buf) {
786 vkc_destroy_buffer(ctx, batch_ctx->masks_buf);
787 }
788 if (batch_ctx->counts_buf) {
789 vkc_destroy_buffer(ctx, batch_ctx->counts_buf);
790 }
791 if (batch_ctx->output_buf) {
792 vkc_destroy_buffer(ctx, batch_ctx->output_buf);
793 }
794 if (batch_ctx->selections_buf) {
795 vkc_destroy_buffer(ctx, batch_ctx->selections_buf);
796 }
797 if (batch_ctx->running_max_distance_buf) {
798 vkc_destroy_buffer(ctx, batch_ctx->running_max_distance_buf);
799 }
800 if (batch_ctx->running_max_index_buf) {
801 vkc_destroy_buffer(ctx, batch_ctx->running_max_index_buf);
802 }
803
804 free(batch_ctx);
805 printf("[VKD Batch] Cleanup complete\n");
806}
807
808/* }}} */
809