libs/vulkan-compute/test-batch-full.lua
1#!/usr/bin/env luajit
2-- test-batch-full.lua - Test batch diversity sequence generation with full dataset
3--
4-- This script tests the GPU-accelerated batch parallel diversity sequence
5-- generation on the complete dataset of 7,797 poems. Expected runtime: 20-30 seconds.
6
7local ffi = require("ffi")
8
9-- {{{ Load dependencies
10-- Set library path for vk_compute to find the shared library
11_G.VK_COMPUTE_LIB = "./build/libvkcompute.so"
12
13-- Add library paths
14package.path = package.path .. ";./lua/?.lua;./?/init.lua"
15package.path = package.path .. ";/home/ritz/programming/ai-stuff/libs/lua/?.lua"
16
17local vk = require("vk_compute")
18
19-- Load JSON library
20local json = require("dkjson")
21-- }}}
22
-- {{{ Configuration
-- Defaults target the full dataset on the original development machine.
-- The file locations and batch size can be overridden on the command line:
--   --embeddings=PATH   embeddings JSON file to read
--   --output=PATH       file the batch job writes its results to
--   --batch-size=N      GPU batch size (positive integer)
--   --debug             enable [DEBUG] progress messages
-- Unrecognized arguments are ignored.
local EMBEDDINGS_FILE = "/home/ritz/programming/ai-stuff/neocities-modernization/assets/embeddings/embeddinggemma_latest/embeddings.json"
local OUTPUT_FILE = "/home/ritz/programming/ai-stuff/neocities-modernization/output/diversity-cache-gpu-batch.bin"
local NUM_POEMS = 7797 -- number of poems expected in the dataset
local EMBEDDING_DIM = 768
local BATCH_SIZE = 3584 -- Optimal for GTX 1080 Ti

-- Parse command line arguments. `arg` is only populated when the file is run
-- as a script by the standalone interpreter, so tolerate its absence.
local DEBUG = false
do
  local args = arg or {}
  for i = 1, #args do
    local a = args[i]
    local key, value = a:match("^%-%-([%w%-]+)=(.+)$")
    if a == "--debug" then
      DEBUG = true
    elseif key == "embeddings" then
      EMBEDDINGS_FILE = value
    elseif key == "output" then
      OUTPUT_FILE = value
    elseif key == "batch-size" then
      local n = tonumber(value)
      if not n or n < 1 or n % 1 ~= 0 then
        error(string.format("Invalid --batch-size value: %s", value))
      end
      BATCH_SIZE = n
    end
  end
end
37
-- Emit a "[DEBUG] "-prefixed message built with string.format(...) when
-- --debug was given; otherwise do nothing. stdout is flushed after each
-- message so the line is written out immediately instead of waiting in the
-- stdio buffer.
local function debug_print(...)
  if not DEBUG then
    return
  end
  local message = string.format(...)
  io.stdout:write("[DEBUG] " .. message .. "\n")
  io.stdout:flush()
end
-- }}}
45
-- {{{ local function load_embeddings()
-- Load embeddings from a JSON file and flatten them into the layout passed to
-- the GPU routine.
--
-- JSON structure: { metadata: {...}, embeddings: [{poem_index, id, embedding: [...]}, ...] }
-- Entries are sorted by poem_index (ascending) and their vectors are
-- concatenated into one 1-based Lua array of (#entries * EMBEDDING_DIM)
-- numbers.
--
-- filepath: path to the embeddings JSON file.
-- Returns: the flat array of numbers, and the number of poems loaded.
-- Raises: if the file cannot be opened/read, the JSON cannot be decoded, an
--   entry lacks a numeric poem_index, two entries share a poem_index, or an
--   embedding is missing, has the wrong length, or contains a non-number.
local function load_embeddings(filepath)
  print(string.format("[Loading] Reading embeddings from: %s", filepath))
  debug_print("Opening file...")

  local f, open_err = io.open(filepath, "rb")
  if not f then
    error(string.format("Failed to open embeddings file: %s (%s)",
      filepath, tostring(open_err)))
  end

  debug_print("Reading file contents...")
  local content, read_err = f:read("*all")
  f:close()
  if not content then
    error(string.format("Failed to read embeddings file: %s (%s)",
      filepath, tostring(read_err)))
  end
  debug_print("Read %d bytes", #content)

  debug_print("Decoding JSON...")
  -- dkjson returns (nil, position, message) instead of raising on bad input.
  local data, _, decode_err = json.decode(content)
  if type(data) ~= "table" then
    error(string.format("Failed to decode embeddings JSON: %s",
      tostring(decode_err or "top-level value is not an object")))
  end
  debug_print("JSON decoded successfully")

  local entries = data.embeddings
  if type(entries) ~= "table" then
    error("Invalid embeddings file format: missing 'embeddings' field")
  end

  -- Validate sort keys up front; otherwise a bad entry surfaces as an opaque
  -- "attempt to compare" error from inside table.sort's comparator.
  for i = 1, #entries do
    local entry = entries[i]
    if type(entry) ~= "table" or type(entry.poem_index) ~= "number" then
      error(string.format("Embedding entry %d has no numeric poem_index", i))
    end
  end

  -- Sort by poem_index to ensure correct order
  debug_print("Sorting %d embeddings by poem_index...", #entries)
  table.sort(entries, function(a, b)
    return a.poem_index < b.poem_index
  end)
  debug_print("Sort complete")

  debug_print("Converting to flat array format...")
  local flat = {}
  local n = 0
  local prev_index
  for _, entry in ipairs(entries) do
    local poem_index = entry.poem_index

    -- After sorting, duplicates are adjacent; they would shift every later
    -- poem's vector to the wrong slot in the flat buffer.
    if poem_index == prev_index then
      error(string.format("Duplicate poem index %d", poem_index))
    end
    prev_index = poem_index

    local embedding = entry.embedding
    if type(embedding) ~= "table" then
      error(string.format("Missing embedding for poem index %d", poem_index))
    end

    if #embedding ~= EMBEDDING_DIM then
      error(string.format("Poem index %d has wrong embedding dimension: %d (expected %d)",
        poem_index, #embedding, EMBEDDING_DIM))
    end

    for j = 1, EMBEDDING_DIM do
      local value = embedding[j]
      -- JSON null (or a string) here would leave a hole / non-number in
      -- the buffer handed to the GPU.
      if type(value) ~= "number" then
        error(string.format("Poem index %d has a non-numeric value at component %d",
          poem_index, j))
      end
      n = n + 1
      flat[n] = value
    end
  end

  local actual_poems = #entries
  print(string.format("[Loading] Loaded %d embeddings (%d floats total)",
    actual_poems, n))

  if actual_poems ~= NUM_POEMS then
    print(string.format("[Warning] Expected %d poems but found %d", NUM_POEMS, actual_poems))
  end

  return flat, actual_poems
end
-- }}}
112
-- {{{ local function main()
-- Horizontal rule printed around section headers (79 '=' characters).
local SEPARATOR = string.rep("=", 79)

-- {{{ local function wall_time()
-- Return a wall-clock timestamp in seconds, for measuring elapsed time.
-- os.clock() reports CPU time consumed by this process, which excludes time
-- spent blocked waiting on the GPU, so it is the wrong clock for timing GPU
-- work. Prefer CLOCK_MONOTONIC via the FFI. If the declaration cannot be
-- installed (e.g. clock_gettime was already declared with a different
-- signature) or the call fails, fall back to os.time(), which only has
-- one-second resolution.
local wall_time = os.time
do
  local CLOCK_MONOTONIC = 1 -- value from Linux <time.h>
  local declared = pcall(ffi.cdef, [[
    typedef struct { long tv_sec; long tv_nsec; } tbf_timespec_t;
    int clock_gettime(int clk_id, tbf_timespec_t *tp);
  ]])
  if declared then
    local ts = ffi.new("tbf_timespec_t")
    local function monotonic_time()
      if ffi.C.clock_gettime(CLOCK_MONOTONIC, ts) ~= 0 then
        error("clock_gettime(CLOCK_MONOTONIC) failed")
      end
      return tonumber(ts.tv_sec) + tonumber(ts.tv_nsec) * 1e-9
    end
    -- Probe once: FFI symbols are resolved lazily and the call may fail.
    if pcall(monotonic_time) then
      wall_time = monotonic_time
    end
  end
end
-- }}}

-- {{{ local function print_sequence_preview(sequences)
-- Print the first (up to) 10 entries and the length of the sequence for
-- poem 0, if the batch call returned one. Builds the preview list only up to
-- the sequence's length so short sequences do not put nil holes into
-- table.concat.
local function print_sequence_preview(sequences)
  if not (sequences and sequences[0]) then
    return
  end
  local seq = sequences[0]
  local preview = {}
  for k = 1, math.min(10, #seq) do
    preview[k] = tostring(seq[k])
  end
  print(" First sequence (poem 0):")
  print(string.format(" First 10 poems: %s", table.concat(preview, ", ")))
  print(string.format(" Length: %d", #seq))
end
-- }}}

-- {{{ local function run_batch(ctx, embeddings, num_poems)
-- Run the batched diversity computation on an initialized Vulkan context and
-- print the timing summary. Errors propagate; the caller owns ctx shutdown.
local function run_batch(ctx, embeddings, num_poems)
  print()
  print("[Batch] Starting batch parallel computation...")
  print(string.format(" Poems: %d", num_poems))
  print(string.format(" Embedding Dim: %d", EMBEDDING_DIM))
  print(string.format(" Batch Size: %d", BATCH_SIZE))
  print(" Expected Time: 20-30 seconds")
  print(string.format(" Output: %s", OUTPUT_FILE))
  print()

  local start_time = wall_time()
  debug_print("Starting batch processing at time %.3f", start_time)

  -- Run batch processing
  debug_print("Calling compute_all_diversity_sequences_batched()")
  local sequences = vk.compute_all_diversity_sequences_batched(
    ctx,
    embeddings,
    num_poems,
    EMBEDDING_DIM,
    OUTPUT_FILE,
    BATCH_SIZE
  )
  debug_print("Batch processing returned")

  local elapsed = wall_time() - start_time

  print()
  print(SEPARATOR)
  print(" Results")
  print(SEPARATOR)
  print(string.format(" Total Time: %.2f seconds", elapsed))
  print(string.format(" Sequences Generated: %d", num_poems))
  print(string.format(" Average Time per Sequence: %.4f seconds", elapsed / num_poems))
  print(string.format(" Output File: %s", OUTPUT_FILE))
  print()

  -- Validate first sequence
  print_sequence_preview(sequences)

  print()
  print("[SUCCESS] Batch processing completed successfully!")
  print()
end
-- }}}

-- Entry point: load embeddings, initialize Vulkan, run the batch job, and
-- always shut the Vulkan context down, even when the job raises.
local function main()
  print(SEPARATOR)
  print(" Batch Parallel Diversity Sequence Generation - Full Dataset Test")
  print(SEPARATOR)
  if DEBUG then
    print(" [DEBUG MODE ENABLED]")
  end
  print()

  -- Load embeddings
  debug_print("Starting load_embeddings()")
  local embeddings = load_embeddings(EMBEDDINGS_FILE)
  debug_print("Embeddings loaded, table size: %d", #embeddings)

  -- Use the number of poems actually loaded rather than the configured
  -- NUM_POEMS: the native routine is told how many poems the buffer holds,
  -- and overstating it would make it read past the end of the data.
  local num_floats = #embeddings
  local num_poems = math.floor(num_floats / EMBEDDING_DIM)
  if num_poems == 0 or num_poems * EMBEDDING_DIM ~= num_floats then
    error(string.format(
      "Embedding buffer has %d floats, which is not a positive multiple of %d",
      num_floats, EMBEDDING_DIM))
  end

  -- Initialize Vulkan
  print("\n[Vulkan] Initializing compute context...")
  debug_print("Calling vk.init(false)")
  local ctx = vk.init(false) -- disable validation layers for performance
  debug_print("vk.init() returned: %s", tostring(ctx))

  if not ctx then
    error("Failed to initialize Vulkan context")
  end

  local ok, err = pcall(run_batch, ctx, embeddings, num_poems)
  if not ok then
    -- Best-effort cleanup; report the original failure, not a shutdown one.
    pcall(vk.shutdown, ctx)
    error(err, 0)
  end

  -- Cleanup
  vk.shutdown(ctx)
end
-- }}}
192
-- Run main with error handling. xpcall + debug.traceback keeps the stack
-- trace of the failure (pcall alone would report only the message), then
-- the process exits with status 1.
local success, err = xpcall(main, debug.traceback)
if not success then
  print("\n[ERROR] " .. tostring(err))
  os.exit(1)
end
199