libs/vulkan-compute/test_diversity_simple.c
1/* test_diversity_simple.c - Simple test of diversity sequence generation
2 *
3 * Tests the diversity sequence algorithm with a small dataset.
4 * NOTE: This is a simplified test - the full algorithm needs centroid updates.
5 */
6
7#include "include/vk_diversity.h"
8#include <stdio.h>
9#include <stdlib.h>
10#include <math.h>
11
12#define EMBEDDING_DIM 768
13#define NUM_POEMS 100
14#define START_POEM 0
15
/*
 * Fill `embeddings` (num_poems rows of `dim` floats, stored row after row)
 * with synthetic test data.
 *
 * Every component of row i is set to i / num_poems plus a rand()-driven
 * offset in [-0.05, 0.05], so rows with nearby indices hold nearby values.
 * rand() is called exactly once per component, in storage order; seed it
 * with srand() beforehand for reproducible output.
 */
void generate_structured_embeddings(float* embeddings, int num_poems, int dim) {
    float* out = embeddings;

    for (int poem = 0; poem < num_poems; poem++) {
        /* Common centre shared by every component of this row. */
        const float centre = (float)poem / num_poems;

        for (int k = 0; k < dim; k++) {
            const float jitter = ((float)rand() / RAND_MAX - 0.5f) * 0.1f;
            *out++ = centre + jitter;
        }
    }
}
27
28int main(void) {
29 printf("=== Diversity Sequence Test ===\n\n");
30 printf("Configuration:\n");
31 printf(" Poems: %d\n", NUM_POEMS);
32 printf(" Dimensions: %d\n", EMBEDDING_DIM);
33 printf(" Start poem: %d\n", START_POEM);
34
35 srand(42);
36
37 /* Generate test data */
38 printf("\n[1] Generating test embeddings...\n");
39 float* embeddings = malloc(NUM_POEMS * EMBEDDING_DIM * sizeof(float));
40 if (!embeddings) {
41 fprintf(stderr, "ERROR: Memory allocation failed\n");
42 return 1;
43 }
44
45 generate_structured_embeddings(embeddings, NUM_POEMS, EMBEDDING_DIM);
46 printf(" [OK] Generated %d embeddings\n", NUM_POEMS);
47
48 /* Initialize Vulkan */
49 printf("\n[2] Initializing Vulkan...\n");
50 VkComputeContext* ctx = vkc_init(false);
51 if (!ctx) {
52 fprintf(stderr, "ERROR: Failed to initialize Vulkan\n");
53 free(embeddings);
54 return 1;
55 }
56 printf(" [OK] Vulkan initialized\n");
57
58 /* Initialize diversity context */
59 printf("\n[3] Initializing diversity context...\n");
60 VkDiversityContext* div_ctx = vkd_init(ctx, embeddings, NUM_POEMS, EMBEDDING_DIM);
61 if (!div_ctx) {
62 fprintf(stderr, "ERROR: Failed to initialize diversity context\n");
63 vkc_destroy(ctx);
64 free(embeddings);
65 return 1;
66 }
67
68 /* Compute diversity sequence */
69 printf("\n[4] Computing diversity sequence...\n");
70 uint32_t* sequence = malloc(NUM_POEMS * sizeof(uint32_t));
71 if (!sequence) {
72 fprintf(stderr, "ERROR: Memory allocation failed\n");
73 vkd_destroy(div_ctx);
74 vkc_destroy(ctx);
75 free(embeddings);
76 return 1;
77 }
78
79 VkComputeResult result = vkd_compute_sequence(div_ctx, START_POEM, sequence);
80 if (result != VKC_SUCCESS) {
81 fprintf(stderr, "ERROR: Diversity sequence computation failed\n");
82 free(sequence);
83 vkd_destroy(div_ctx);
84 vkc_destroy(ctx);
85 free(embeddings);
86 return 1;
87 }
88
89 /* Verify sequence */
90 printf("\n[5] Verifying sequence...\n");
91 bool valid = true;
92
93 /* Check that sequence starts with start_poem */
94 if (sequence[0] != START_POEM) {
95 fprintf(stderr, " ERROR: Sequence doesn't start with start_poem\n");
96 valid = false;
97 }
98
99 /* Check that all poems appear exactly once */
100 uint8_t* seen = calloc(NUM_POEMS, sizeof(uint8_t));
101 for (int i = 0; i < NUM_POEMS; i++) {
102 uint32_t poem_id = sequence[i];
103 if (poem_id >= NUM_POEMS) {
104 fprintf(stderr, " ERROR: Invalid poem ID: %u\n", poem_id);
105 valid = false;
106 break;
107 }
108 if (seen[poem_id]) {
109 fprintf(stderr, " ERROR: Duplicate poem ID: %u\n", poem_id);
110 valid = false;
111 break;
112 }
113 seen[poem_id] = 1;
114 }
115
116 free(seen);
117
118 if (valid) {
119 printf(" [OK] Sequence is valid\n");
120 printf("\n First 10 poems in sequence:\n");
121 for (int i = 0; i < 10 && i < NUM_POEMS; i++) {
122 printf(" [%d] → poem %u\n", i, sequence[i]);
123 }
124 } else {
125 printf(" [FAILED] Sequence validation failed\n");
126 }
127
128 /* Cleanup */
129 printf("\n[6] Cleaning up...\n");
130 free(sequence);
131 free(embeddings);
132 vkd_destroy(div_ctx);
133 vkc_destroy(ctx);
134
135 if (valid) {
136 printf("\n[SUCCESS] Diversity sequence test passed!\n");
137 printf("\nNOTE: This is a simplified test. Full algorithm requires centroid updates.\n");
138 return 0;
139 } else {
140 printf("\n[FAILED] Test failed\n");
141 return 1;
142 }
143}
144