libs/fuzzy-computing.lua

556 lines

1
2local M = {}
3local dkjson = require("libs/dkjson")
4local inference_config = require("libs/inference-server-config")
5local text_chunking = require("libs/text-chunking")
6
-- Issue 10-050: per-request bounds for a single /v1/embeddings call.
--
-- llama-server runs with --parallel 1 by default, i.e. a single slot, so an
-- array request is executed as N SEQUENTIAL tasks rather than one fused
-- forward pass. Batching therefore saves HTTP/curl round trips, not GPU time,
-- and a request has to be limited by the total WORK it carries, not merely by
-- how many items it holds. 16 one-word inputs are trivial; 16 max-size chunks
-- cost roughly 16x a single chunk and can outrun the client timeout (which is
-- exactly what stalled the giant-poem tail of the corpus).
--
-- Each request is therefore capped by a token budget AND an item count, and
-- chunks are packed greedily until whichever limit is reached first.
M.REQUEST_TOKEN_BUDGET = 4000 -- summed per-chunk token counts per request
M.BATCH_SIZE = 16 -- hard cap on array length per request
20
-- Unique temp filenames for the curl request/response files.
--
-- os.time() alone is not enough: it ticks once per second while we issue many
-- requests per second. A per-call counter makes names unique within this Lua
-- process. However, every process starts its counter at 0, so two worker
-- processes launched in the same second would still build identical names and
-- clobber each other's files. Each name therefore also carries a per-process
-- token derived once from os.tmpname(), which asks the C library for a fresh,
-- (practically) machine-unique name.
local request_counter = 0
local process_token -- lazily initialised by get_process_token()

-- Returns an alphanumeric per-process discriminator, or "" if none is available.
local function get_process_token()
  if process_token == nil then
    local token = ""
    local ok, name = pcall(os.tmpname)
    if ok and type(name) == "string" then
      -- On POSIX builds os.tmpname() also creates the (empty) file; we only
      -- want its unique name, so remove it right away.
      os.remove(name)
      -- Keep only the basename's alphanumerics: the resulting path is spliced
      -- unquoted into a shell command, so no separators or metacharacters.
      token = (name:match("([^/\\]+)$") or ""):gsub("[^%w]", "")
    end
    process_token = token
  end
  return process_token
end

-- Returns "/tmp/<label>[_<token>]_<unix time>_<counter>.json", unique per call.
local function unique_tmp_path(label)
  request_counter = request_counter + 1
  local token = get_process_token()
  if token == "" then
    return string.format("/tmp/%s_%d_%d.json", label, os.time(), request_counter)
  end
  return string.format("/tmp/%s_%s_%d_%d.json", label, token, os.time(), request_counter)
end
30
-- {{{ function M.sanitize_utf8(s)

-- Classify a UTF-8 lead byte: returns the sequence length plus the inclusive
-- bounds allowed for the FIRST continuation byte, or nil for an invalid lead
-- (0x80-0xC1, which includes the overlong 2-byte leads, and 0xF5-0xFF).
-- The per-lead first-continuation range is what makes the check STRICT:
-- it rejects overlong forms, UTF-16 surrogates, and code points > U+10FFFF.
local function utf8_lead_info(b)
  if b < 0x80 then return 1 end
  if b < 0xC2 then return nil end
  if b <= 0xDF then return 2, 0x80, 0xBF end
  if b == 0xE0 then return 3, 0xA0, 0xBF end -- no overlong
  if b == 0xED then return 3, 0x80, 0x9F end -- no surrogates
  if b <= 0xEF then return 3, 0x80, 0xBF end
  if b == 0xF0 then return 4, 0x90, 0xBF end -- no overlong
  if b <= 0xF3 then return 4, 0x80, 0xBF end
  if b == 0xF4 then return 4, 0x80, 0x8F end -- cap at U+10FFFF
  return nil
end

-- Remove every byte that is not part of a well-formed UTF-8 sequence and
-- return (cleaned_string, number_of_bytes_removed). Non-string input is
-- returned untouched with a count of 0.
--
-- Why this is mandatory: llama-server serializes responses with nlohmann::json,
-- which THROWS on an invalid UTF-8 byte, and the exception is not caught. A
-- single bad byte (e.g. a lone 0xB5 from a PDF mis-saved as a .txt note)
-- therefore takes down the whole server process. We cannot fix llama.cpp, so
-- we guarantee never to send it malformed text. A merely structural check
-- (any 0x80-0xBF continuation) is NOT enough: nlohmann rejects such sequences,
-- and binary blobs are full of them.
function M.sanitize_utf8(s)
  if type(s) ~= "string" then
    return s, 0
  end

  local pieces, dropped = {}, 0
  local total = #s
  local pos = 1
  local run_start = 1 -- first byte of the current run of valid bytes

  while pos <= total do
    local size, lo, hi = utf8_lead_info(s:byte(pos))
    local ok = false
    if size == 1 then
      ok = true
    elseif size and pos + size - 1 <= total then
      local second = s:byte(pos + 1)
      ok = second >= lo and second <= hi
      local offset = 2
      while ok and offset < size do
        local cont = s:byte(pos + offset)
        ok = cont >= 0x80 and cont <= 0xBF
        offset = offset + 1
      end
    end

    if ok then
      pos = pos + size
    else
      -- Flush the valid run preceding the bad byte, skip that one byte, and
      -- resynchronise on the very next byte.
      if run_start < pos then
        pieces[#pieces + 1] = s:sub(run_start, pos - 1)
      end
      dropped = dropped + 1
      pos = pos + 1
      run_start = pos
    end
  end

  if run_start <= total then
    pieces[#pieces + 1] = s:sub(run_start, total)
  end
  return table.concat(pieces), dropped
end
-- }}}
89
-- DEPRECATED chat helper. Sends `context` (placed unchanged in the request's
-- "messages" field) to the inference server's /api/chat endpoint with
-- stream=false, and returns response.message.content. Returns nil if the
-- request file cannot be created, no response file is produced, or the
-- response has no "message" field.
function M.generate(context, model) -- {{{ (DEPRECATED - use M.get_embedding instead)
  local json_data = dkjson.encode({
    model = model,
    messages = context,
    stream = false,
  })

  -- Temp files for curl. unique_tmp_path() rather than a bare os.time() stamp,
  -- so two calls in the same second cannot share (and clobber) a filename.
  local input_file = unique_tmp_path("llm_request")
  local output_file = unique_tmp_path("llm_response")

  local f = io.open(input_file, "w")
  if not f then
    return nil -- could not create the request file; nothing to clean up
  end
  f:write(json_data)
  f:close()

  -- Issue 10-017: Use build_host_url() instead of deprecated OLLAMA_ENDPOINT
  local curl_cmd = string.format(
    "curl -s -X POST %s/api/chat -H 'Content-Type: application/json' -d @%s > %s",
    inference_config.build_host_url(), input_file, output_file
  )
  os.execute(curl_cmd)

  local response_file = io.open(output_file, "r")
  if not response_file then
    os.remove(input_file)
    return nil
  end
  local response_text = response_file:read("*all")
  response_file:close()

  os.remove(input_file)
  os.remove(output_file)

  local response = dkjson.decode(response_text)
  if response and response.message then
    return response.message.content
  end
  return nil
end -- }}}
138
-- Embed several texts in ONE request to the inference server's
-- OpenAI-compatible /v1/embeddings endpoint.
--
-- Issue 10-049: migrated from Ollama's /api/embeddings shape (request field
-- "prompt", vector returned directly under "embedding") to the OpenAI shape
-- (request field "input", vectors nested under data[N].embedding).
--
-- Issue 10-050: the OpenAI shape accepts an ARRAY of inputs and returns
-- data[0..N-1]. With llama-server's default --parallel 1 those inputs are
-- still processed as sequential tasks on a single slot, not as one fused
-- forward pass. The saving is therefore in HTTP/curl round trips, and callers
-- must keep each request's total work bounded (see M.REQUEST_TOKEN_BUDGET).
-- Every input gets its own task prefix (the prefix is per-item, not per-batch).
--
-- texts     : array of inputs. An empty or non-table value returns ({}, nil).
-- model     : model name placed in the request body.
-- endpoint  : optional base URL; defaults to inference_config.build_host_url().
-- format_fn : optional per-input prefix function; defaults to
--             inference_config.format_embedding_prompt.
--   endpoint/format_fn exist so a caller that maintains its OWN
--   inference-server-config instance (e.g. src/similarity-engine.lua, which
--   requires the module under a different key and thus gets a different state
--   object) can pass its already-resolved values through, rather than relying
--   on THIS module's config state. The simpler call sites (colors, words) omit
--   them.
-- returns   : (vectors, nil), where vectors[i] is the embedding for texts[i]
--             (nil where the server returned none for that slot); or
--             (nil, err), with err being "file_error", "no_response", or
--             "parse_error: <first 200 bytes of the reply>".
function M.get_embeddings_batch(texts, model, endpoint, format_fn) -- {{{
  if type(texts) ~= "table" or #texts == 0 then
    return {}, nil
  end

  endpoint = endpoint or inference_config.build_host_url()
  format_fn = format_fn or inference_config.format_embedding_prompt

  -- Strip invalid UTF-8 from every input first (it would crash the server; see
  -- M.sanitize_utf8), then prefix it (e.g. "clustering: " for nomic v1.5+).
  -- Warn once per batch if anything was stripped, so corrupt/binary poem
  -- content (e.g. a PDF mis-saved as a .txt) is visible rather than silent.
  local prefixed = {}
  local removed_total = 0
  for i = 1, #texts do
    local clean, removed = M.sanitize_utf8(texts[i])
    removed_total = removed_total + removed
    prefixed[i] = format_fn(clean)
  end
  if removed_total > 0 then
    io.stderr:write(string.format(
      "[WARN] sanitize_utf8: stripped %d invalid UTF-8 byte(s) from this "
        .. "embedding batch (corrupt/binary poem content)\n", removed_total))
  end

  local json_data = dkjson.encode({ model = model, input = prefixed })

  local input_file = unique_tmp_path("embedding_batch_request")
  local output_file = unique_tmp_path("embedding_batch_response")

  local f = io.open(input_file, "w")
  if not f then
    return nil, "file_error"
  end
  f:write(json_data)
  f:close()

  -- --max-time 180s headroom: with --parallel 1 a request's inputs run as
  -- sequential tasks, so a token-bounded batch should finish well inside this.
  -- curl runs without -f, so an HTTP error body still lands in output_file and
  -- is reported below as a parse_error.
  local curl_cmd = string.format(
    "curl -s --connect-timeout 10 --max-time 180 -X POST %s/v1/embeddings " ..
      "-H 'Content-Type: application/json' -d @%s > %s",
    endpoint, input_file, output_file
  )
  os.execute(curl_cmd)

  local response_file = io.open(output_file, "r")
  if not response_file then
    os.remove(input_file)
    return nil, "no_response"
  end
  local response_text = response_file:read("*all")
  response_file:close()
  os.remove(input_file)
  os.remove(output_file)

  local response = dkjson.decode(response_text)
  if not (type(response) == "table" and type(response.data) == "table") then
    -- Surface the server's own error text (e.g. a 422 "input too large"),
    -- truncated, so the caller can decide whether to chunk and retry.
    return nil, "parse_error: " .. (tostring(response_text):sub(1, 200))
  end

  -- Place each returned vector at the slot its .index names. OpenAI's index is
  -- 0-based and the server may return data out of order, so .index is trusted
  -- over array position; sequential order is the fallback only when .index is
  -- absent or not a number. Slots outside 1..#texts are ignored so a
  -- misbehaving server cannot add stray entries. The result is aligned 1:1
  -- with `texts` (nil where a vector is missing, so the caller can
  -- single-retry just that slot).
  local out = {}
  local next_slot = 1
  for _, item in ipairs(response.data) do
    if type(item) == "table" then
      local slot = next_slot
      if type(item.index) == "number" then
        slot = item.index + 1
      end
      if slot >= 1 and slot <= #texts and slot == math.floor(slot) then
        out[slot] = item.embedding
      end
    end
    next_slot = next_slot + 1
  end
  return out, nil
end -- }}}
242
-- get_embedding: single-input convenience shim over get_embeddings_batch.
-- Issue 10-050 collapsed the old standalone implementation into this wrapper
-- so there is exactly one code path to the embedding endpoint. Returns the
-- bare vector, or nil on any failure, to preserve the original call contract
-- (the batch call's error string is not propagated).
--
-- endpoint and format_fn are optional and forwarded unchanged to
-- get_embeddings_batch. This lets a caller that holds its own
-- inference-server-config instance use the shim too. Omitting them keeps the
-- previous behaviour.
function M.get_embedding(text, model, endpoint, format_fn) -- {{{
  local vectors = M.get_embeddings_batch({ text }, model, endpoint, format_fn)
  if vectors == nil then
    return nil
  end
  return vectors[1]
end -- }}}
251
-- {{{ function M.tokenize_count(text, endpoint)
-- Count the tokens `text` produces under the loaded model's tokenizer, using
-- llama-server's POST /tokenize ({"content": ...} -> {"tokens": [...]}).
-- Returns the count, or nil on any failure (request file not writable, no
-- response, unparseable reply, no "tokens" array), so the caller can decide
-- what to do. Issue 10-050: lets chunking work in real tokens instead of a
-- character heuristic, removing truncation risk on dense/non-English text.
-- endpoint defaults to this module's inference_config.build_host_url().
function M.tokenize_count(text, endpoint)
  local base_url = endpoint or inference_config.build_host_url()

  -- Malformed UTF-8 would crash the server (see M.sanitize_utf8), so strip it.
  -- No warning here: get_embeddings_batch already reports stripped bytes, and
  -- one warning per chunk per poem would flood stderr.
  local safe_text = (M.sanitize_utf8(text))

  local request_path = unique_tmp_path("tokenize_request")
  local response_path = unique_tmp_path("tokenize_response")

  local request = io.open(request_path, "w")
  if request == nil then
    return nil
  end
  request:write(dkjson.encode({ content = safe_text }))
  request:close()

  os.execute(string.format(
    "curl -s --connect-timeout 10 --max-time 60 -X POST %s/tokenize " ..
      "-H 'Content-Type: application/json' -d @%s > %s",
    base_url, request_path, response_path))

  local reply = io.open(response_path, "r")
  if reply == nil then
    os.remove(request_path)
    return nil
  end
  local body = reply:read("*a")
  reply:close()
  os.remove(request_path)
  os.remove(response_path)

  local decoded = dkjson.decode(body)
  if not decoded or type(decoded.tokens) ~= "table" then
    return nil
  end
  return #decoded.tokens
end
-- }}}
297
-- Hard limits of the embedding model, used to compute an EXACT per-chunk token
-- budget instead of guessing at headroom:
--   * nomic-embed-text v1.5 was trained at 2048 tokens; beyond that the model
--     truncates.
--   * BERT wraps every input as [CLS] ... [SEP], so 2 tokens of every request
--     are structural specials.
-- Both values are model-specific; a model swap means regenerating embeddings
-- anyway.
M.MODEL_CONTEXT_TOKENS = 2048 -- model context window, in tokens
M.EMBED_SPECIAL_TOKENS = 2 -- [CLS] + [SEP]
305
-- {{{ function M.make_token_counter(endpoint)
-- Build a count_fn(string) -> token_count for the chunker, bound to `endpoint`
-- (nil means this module's configured host). Every call goes to /tokenize via
-- M.tokenize_count. If /tokenize cannot answer, the counter raises instead of
-- returning nil: a chunk cannot be sized safely without the real count, and a
-- fallback estimate would only hide the outage until the embed call failed
-- anyway. Fail loud.
function M.make_token_counter(endpoint)
  local function count_tokens(s)
    local count = M.tokenize_count(s, endpoint)
    if count then
      return count
    end
    error("make_token_counter: /tokenize unreachable at "
      .. tostring(endpoint or inference_config.build_host_url())
      .. " — cannot size chunks safely")
  end
  return count_tokens
end
-- }}}
323
-- {{{ function M.embedding_chunk_budget(endpoint, format_fn)
-- The EXACT per-chunk token budget. It is the model context, minus the 2 BERT
-- specials, minus the per-input task prefix tokenized once (format_fn("")
-- yields just the prefix). A chunk sized to this budget, once prefixed and
-- wrapped in the specials, lands exactly on the model's context limit.
--
-- format_fn defaults to inference_config.format_embedding_prompt, the same
-- default get_embeddings_batch applies to a nil format_fn. A caller that omits
-- it on both calls therefore gets a budget that accounts for the prefix
-- actually sent. Treating nil as "no prefix" would let full-size chunks
-- overshoot the context by the prefix length.
-- Raises if the prefix is non-empty and /tokenize is unreachable.
function M.embedding_chunk_budget(endpoint, format_fn)
  format_fn = format_fn or inference_config.format_embedding_prompt
  local prefix = (format_fn and format_fn("")) or ""
  local prefix_tokens = 0
  if type(prefix) == "string" and #prefix > 0 then
    prefix_tokens = M.tokenize_count(prefix, endpoint)
    if not prefix_tokens then
      error("embedding_chunk_budget: /tokenize unreachable at "
        .. tostring(endpoint or inference_config.build_host_url())
        .. " — cannot size chunks safely")
    end
  end
  return M.MODEL_CONTEXT_TOKENS - M.EMBED_SPECIAL_TOKENS - prefix_tokens
end
-- }}}
343
-- {{{ function M._embed_with_chunking_impl(texts, batch_fn, count_fn, max_tokens, strategy)
-- Pure orchestration core for "one vector per input text, even when a text is
-- too long for the model". It is separated from the network call (batch_fn)
-- ON PURPOSE, so it can be unit-tested offline with a mock embedder. This is
-- the part with the fiddly index bookkeeping, where a silent bug would hide.
--
-- Pipeline:
--   1. Chunk every text.
--   2. FLATTEN all chunks into one list, remembering which chunks belong to
--      which text.
--   3. Pack the flat list into requests bounded by a TOKEN budget and an item
--      cap, and embed each request.
--   4. Gather each text's chunk vectors back and RECOMBINE them (weighted by
--      chunk byte length) into one vector.
--
-- A failed request just leaves its slots nil (the caller single-retries those
-- items) rather than being split further. The token budget is what keeps each
-- request light enough to succeed.
--
-- texts     : array of strings. Texts that produce no chunks (e.g. whitespace
--             only) yield out[i] = nil.
-- batch_fn  : function(sub_texts) -> (vectors_aligned, err). Returns nil on a
--             request failure (timeout, server down, rejected input).
-- count_fn  : function(string) -> exact token count (drives chunk sizing).
-- max_tokens: per-chunk token ceiling passed to chunk_text_by_tokens.
-- strategy  : passed through to text_chunking.combine_chunk_vectors.
-- returns   : (out, nil), where out[i] is text i's combined vector. out[i] is
--             nil if the text produced no chunks or one of its chunks failed
--             to embed (the caller may single-retry that one).
--             If at least one chunk was sent but NOTHING embedded (every
--             request failed, i.e. the server looks down), returns
--             (nil, "all_requests_failed") so the caller's network
--             backoff/threshold logic still triggers. A batch with no chunks
--             at all sends no request and is NOT reported as a failure.
function M._embed_with_chunking_impl(texts, batch_fn, count_fn, max_tokens, strategy)
  -- 1. Chunk each text and flatten, recording each text's slice of the flat
  --    list. chunk_text_by_tokens returns the EXACT token count of every
  --    chunk, so request packing below never has to estimate.
  local flat = {}
  local flat_tokens = {} -- exact token count of each flat chunk
  local slices = {} -- slices[i] = { start =, count =, weights = {byte lengths} }
  for i = 1, #texts do
    local chunks, counts = text_chunking.chunk_text_by_tokens(texts[i], count_fn, max_tokens)
    local start_pos = #flat + 1
    local weights = {}
    for j = 1, #chunks do
      flat[#flat + 1] = chunks[j]
      flat_tokens[#flat_tokens + 1] = counts[j]
      weights[j] = #chunks[j]
    end
    slices[i] = { start = start_pos, count = #chunks, weights = weights }
  end

  -- 2. Pack the flat chunk list into requests bounded by BOTH a token budget
  --    and a hard item cap, then embed each request (one request per group).
  --    Packing by tokens keeps a request's work bounded whether its chunks
  --    are tiny words or near-max chunks.
  local flat_vectors = {}
  local idx = 1
  while idx <= #flat do
    local sub = {}
    local sub_start = idx
    local budget = 0
    while idx <= #flat and #sub < M.BATCH_SIZE do
      local tok = flat_tokens[idx]
      -- Always take at least one item, even if it alone exceeds the budget
      -- (a single max chunk is still a legal request; it fits the model).
      if #sub > 0 and (budget + tok) > M.REQUEST_TOKEN_BUDGET then
        break
      end
      sub[#sub + 1] = flat[idx]
      budget = budget + tok
      idx = idx + 1
    end
    -- On failure, leave these slots nil and move on. Step 3 marks the
    -- affected text(s) nil, and the caller single-retries just those.
    local vectors = batch_fn(sub)
    if vectors then
      for k = 1, #sub do
        flat_vectors[sub_start + k - 1] = vectors[k]
      end
    end
  end

  -- 3. Gather each text's chunk vectors and recombine them into one vector.
  local out = {}
  local any_success = false
  for i = 1, #texts do
    local s = slices[i]
    if s.count == 0 then
      out[i] = nil -- empty/whitespace text produced no chunks
    else
      local chunk_vecs = {}
      local complete = true
      for k = 0, s.count - 1 do
        local v = flat_vectors[s.start + k]
        if type(v) ~= "table" or #v == 0 then
          complete = false
          break
        end
        chunk_vecs[#chunk_vecs + 1] = v
      end
      if complete then
        out[i] = text_chunking.combine_chunk_vectors(chunk_vecs, s.weights, strategy)
        any_success = true
      else
        out[i] = nil -- a chunk is missing; let the caller decide to retry
      end
    end
  end

  -- If chunks were sent and not a single text embedded, the server is almost
  -- certainly unreachable (a content problem would fail only specific items).
  -- Signal a whole-batch failure so the caller backs off / counts it toward
  -- its error threshold. When there were no chunks at all, no request was
  -- made, so there is nothing to blame on the server.
  if not any_success and #flat > 0 then
    return nil, "all_requests_failed"
  end
  return out, nil
end
-- }}}
460
-- {{{ function M.embed_texts_with_chunking(texts, model, opts)
-- Production entry point: M._embed_with_chunking_impl wired to the real
-- /v1/embeddings batch call. Returns one vector per input text (nil where a
-- text's embedding could not be produced), or (nil, err) on a whole-batch
-- failure so the caller can apply retry/backoff.
--
-- opts (all optional): { endpoint=, format_fn=, count_fn=, max_tokens=, strategy= }.
-- endpoint/format_fn are threaded to get_embeddings_batch so a caller with its
-- own config instance stays the single source of truth. format_fn defaults to
-- inference_config.format_embedding_prompt. The SAME resolved function both
-- sizes the chunk budget and prefixes the requests, so the budget always
-- accounts for the prefix that is actually sent.
-- count_fn defaults to the EXACT /tokenize-backed counter, and max_tokens to the
-- exact computed budget. There is NO estimate fallback: if /tokenize is down the
-- embed call is down too, so we fail loudly rather than mask it (Issue 10-050).
-- Raises (via the counter or the budget computation) when /tokenize is
-- unreachable.
function M.embed_texts_with_chunking(texts, model, opts)
  opts = opts or {}
  local endpoint = opts.endpoint
  -- Resolve the prefix function once, here, and hand the same function to
  -- both the budget and the requests. This guarantees they agree on the prefix
  -- regardless of how either callee would default a nil format_fn. Otherwise,
  -- chunks sized without the prefix overflow the model context once prefixed.
  local format_fn = opts.format_fn or inference_config.format_embedding_prompt
  local count_fn = opts.count_fn or M.make_token_counter(endpoint)
  local max_tokens = opts.max_tokens or M.embedding_chunk_budget(endpoint, format_fn)
  return M._embed_with_chunking_impl(
    texts,
    function(sub)
      return M.get_embeddings_batch(sub, model, endpoint, format_fn)
    end,
    count_fn,
    max_tokens,
    opts.strategy)
end
-- }}}
488
-- Cosine similarity of two numeric vectors: dot(a, b) / (|a| * |b|).
-- Returns 0 when either vector is missing, their lengths differ, or either
-- vector has zero magnitude (which also covers two empty vectors).
function M.cosine_similarity(vec1, vec2) -- {{{
  if not (vec1 and vec2) or #vec1 ~= #vec2 then
    return 0
  end

  local dot, sum_sq1, sum_sq2 = 0, 0, 0
  for i = 1, #vec1 do
    local a, b = vec1[i], vec2[i]
    dot = dot + (a * b)
    sum_sq1 = sum_sq1 + (a * a)
    sum_sq2 = sum_sq2 + (b * b)
  end

  local norm1, norm2 = math.sqrt(sum_sq1), math.sqrt(sum_sq2)
  if norm1 == 0 or norm2 == 0 then
    return 0
  end
  return dot / (norm1 * norm2)
end -- }}}
514
-- Pick the theme whose embedding is most cosine-similar to text_embedding.
-- Returns (theme, similarity). If theme_embeddings is empty, or no theme
-- scores above -1, the result is ("neutral", -1). Ties keep whichever theme
-- pairs() visits first (pairs order is unspecified).
function M.find_most_similar_theme(text_embedding, theme_embeddings) -- {{{
  local winner, top_score = "neutral", -1
  for name, embedding in pairs(theme_embeddings) do
    local score = M.cosine_similarity(text_embedding, embedding)
    if score > top_score then
      winner, top_score = name, score
    end
  end
  return winner, top_score
end -- }}}
530
-- Like find_most_similar_theme, but discounts themes by frequency_weights to
-- spread assignments across themes. frequency_weights[theme] is presumably the
-- number of times the theme has been used (confirm with callers); missing
-- entries, or a nil frequency_weights table, count as 0. Each unit reduces the
-- multiplier by 0.1, floored at 0, so a weight of 10 or more scores 0.
--
-- Returns (best_theme, its raw cosine similarity, its weighted score), or
-- ("neutral", -1, -1) when no theme's weighted score exceeds -1.
--
-- NOTE(review): the multiplier is applied even when the raw similarity is
-- negative. There, shrinking it RAISES the score toward 0, so heavily used
-- themes are favoured over fresh ones among negative matches. This is left
-- as-is because changing the formula would change assignments; confirm
-- whether negative similarities occur in practice.
function M.find_most_similar_theme_weighted(text_embedding, theme_embeddings, frequency_weights) -- {{{
  -- A missing weight table means "no usage recorded yet" rather than an error.
  frequency_weights = frequency_weights or {}

  local best_theme = "neutral"
  local best_weighted_score = -1
  local best_raw_similarity = -1

  for theme, theme_embedding in pairs(theme_embeddings) do
    local raw_similarity = M.cosine_similarity(text_embedding, theme_embedding)

    -- Frequency-based weighting: reduce by 10% per use, never below 0.
    local frequency_penalty = frequency_weights[theme] or 0
    local diversity_boost = math.max(0, 1.0 - (frequency_penalty * 0.1))
    local weighted_score = raw_similarity * diversity_boost

    if weighted_score > best_weighted_score then
      best_weighted_score = weighted_score
      best_raw_similarity = raw_similarity
      best_theme = theme
    end
  end

  return best_theme, best_raw_similarity, best_weighted_score
end -- }}}
554
555return M
556