libs/vulkan-compute/include/vk_diversity.h
1/* vk_diversity.h - High-level API for diversity sequence generation
2 *
3 * This header provides a simplified interface for computing diversity
4 * sequences on the GPU using the Vulkan compute infrastructure.
5 */
6
7#ifndef VK_DIVERSITY_H
8#define VK_DIVERSITY_H
9
10#include "vk_compute.h"
11
12#ifdef __cplusplus
13extern "C" {
14#endif
15
16/* Opaque handle for a diversity sequence computation session (see vkd_init / vkd_destroy) */
17typedef struct VkDiversityContext VkDiversityContext;
18
19/* Initialize diversity sequence computation
20 *
21 * Parameters:
22 * ctx - Vulkan compute context
23 * embeddings - All poem embeddings as FP32 (num_poems * embedding_dim floats)
24 * num_poems - Number of poems in `embeddings`
25 * embedding_dim - Dimension of each embedding (e.g., 768)
26 *
27 * Returns: Diversity context handle (release with vkd_destroy) or NULL on error
28 */
29VkDiversityContext* vkd_init(VkComputeContext* ctx,
30 const float* embeddings,
31 uint32_t num_poems,
32 uint32_t embedding_dim);
33
34/* Compute a single diversity sequence
35 *
36 * Parameters:
37 * div_ctx - Diversity context returned by vkd_init
38 * start_poem - Index of starting poem (0 to num_poems-1)
39 * output_sequence - Caller-provided output buffer (room for num_poems indices)
40 *
41 * Returns: VKC_SUCCESS or error code
42 */
43VkComputeResult vkd_compute_sequence(VkDiversityContext* div_ctx,
44 uint32_t start_poem,
45 uint32_t* output_sequence);
46
47/* Clean up a diversity context obtained from vkd_init */
48void vkd_destroy(VkDiversityContext* div_ctx);
49
50/* ===========================================================================
51 * Batch Processing API - Parallel computation of multiple sequences
52 * ===========================================================================
53 * These functions enable computing thousands of diversity sequences
54 * simultaneously with GPU-side state management for optimal performance.
55 */
56
57/* Opaque handle for batch diversity computation (see vkd_batch_init / vkd_batch_destroy) */
58typedef struct VkDiversityBatchContext VkDiversityBatchContext;
59
60/* Initialize batch diversity computation
61 *
62 * Embeddings are FP16, packed two per 32-bit uint, with low 16 bits = value
63 * at even-index dim and high 16 bits = value at odd-index dim. Caller is
64 * responsible for the FP32 -> FP16 conversion via vkc_fp32_to_fp16().
65 * embedding_dim MUST be even (true for 768 and 2560; check before calling).
66 *
67 * Parameters:
68 * ctx - Vulkan compute context
69 * embeddings_fp16 - All poem embeddings, FP16 packed:
70 * (num_poems * embedding_dim / 2) 32-bit uints, i.e.
71 * (num_poems * embedding_dim) uint16_t values = (num_poems * embedding_dim * 2) bytes.
72 * num_poems - Total number of poems (e.g., 7797)
73 * embedding_dim - Embedding dimension; must be even (e.g., 768, 2560)
74 * batch_size - Number of sequences to compute in parallel (e.g., 3584)
75 * start_indices - Array of starting poem indices (batch_size elements)
76 *
77 * Returns: Batch context handle (release with vkd_batch_destroy) or NULL on error
78 *
79 * Note: Batch size should be <= 3584 for optimal GPU utilization
80 */
81VkDiversityBatchContext* vkd_batch_init(VkComputeContext* ctx,
82 const uint16_t* embeddings_fp16,
83 uint32_t num_poems,
84 uint32_t embedding_dim,
85 uint32_t batch_size,
86 const uint32_t* start_indices);
87
88/* Run a chunk of diversity-sequence iterations on the GPU.
89 *
90 * Each workgroup advances its sequence by `slot_count` slots, starting at
91 * output slot `start_slot`. Centroid + count + mask state persists in the
92 * shared storage buffers between calls, so a subsequent call whose
93 * start_slot equals this call's start_slot + slot_count resumes exactly
94 * where this one left off.
95 *
96 * The chunked design exists because attempting to compute every iteration
97 * in a single dispatch trips the kernel GPU watchdog (~10 seconds on
98 * Linux+NVIDIA with an active display). Calling this in a loop with a
99 * chunk size that yields under ~1-2 seconds of GPU work per call avoids
100 * that and still amortizes per-dispatch overhead nearly perfectly
101 * compared to the old per-iteration approach (8358 dispatches per batch).
102 *
103 * Parameters:
104 * batch_ctx - Batch context returned by vkd_batch_init
105 * start_slot - First output-sequence slot to write (1 on the first call;
106 * slot 0 is the seed, written by vkd_batch_init)
107 * slot_count - How many slots to write in this call
108 * tile_size - Number of candidate poems per L2-friendly tile in the
109 * inner scan. Pass the batch's num_poems (or 0) for the
110 * non-tiled baseline; pass a smaller value to enable the
111 * 9-014 tiling optimization. A reasonable derivation is
112 * floor(L2_BYTES * 0.85 / (embedding_dim * 2)) since each
113 * FP16-packed candidate is embedding_dim * 2 bytes.
114 *
115 * Returns: VKC_SUCCESS, or the error code from the underlying dispatch
116 * (e.g. VKC_ERROR_COMMAND_EXECUTION_FAILED on device-lost). The
117 * caller is responsible for stopping the loop and reporting on
118 * any non-success return.
119 */
120VkComputeResult vkd_batch_compute_chunk(VkDiversityBatchContext* batch_ctx,
121 uint32_t start_slot,
122 uint32_t slot_count,
123 uint32_t tile_size);
124
125/* 9-014: Dispatch-per-tile + pipelined version of vkd_batch_compute_chunk.
126 *
127 * Same logical work as vkd_batch_compute_chunk but the per-iteration tile
128 * loop is unrolled into separate dispatches with a fence wait between
129 * each tile (hard grid sync — guarantees all workgroups finish tile K
130 * before any start tile K+1). The async command-buffer pool in
131 * VkComputeContext lets the CPU stay ahead of the GPU by N submissions,
132 * hiding per-dispatch CPU overhead. See 9-014 for the architecture
133 * rationale and expected speedups.
134 *
135 * Parameters (batch_ctx, start_slot, slot_count, tile_size) mirror
136 * vkd_batch_compute_chunk. The two are interchangeable from the caller's perspective;
137 * they produce identical sequences (by associativity of max across tile boundaries)
138 * but reach the answer via different access patterns and synchronization granularity.
139 *
140 * Drains the async pool before returning, so on return all of the chunk's
141 * work is complete and the output buffer is safe to read.
142 */
143VkComputeResult vkd_batch_compute_chunk_pipelined(VkDiversityBatchContext* batch_ctx,
144 uint32_t start_slot,
145 uint32_t slot_count,
146 uint32_t tile_size);
147
148/* Download complete sequences from GPU
149 *
150 * Parameters:
151 * batch_ctx - Batch context returned by vkd_batch_init
152 * output_sequences - Caller-provided output buffer (batch_size * num_poems uint32_t indices)
153 *
154 * Returns: VKC_SUCCESS or error code
155 */
156VkComputeResult vkd_batch_download_sequences(VkDiversityBatchContext* batch_ctx,
157 uint32_t* output_sequences);
158
159/* Clean up a batch context obtained from vkd_batch_init */
160void vkd_batch_destroy(VkDiversityBatchContext* batch_ctx);
161
162#ifdef __cplusplus
163}
164#endif
165
166#endif /* VK_DIVERSITY_H */
167