From e290d6d869d4877ce2aeddba77d212504eea9fb6 Mon Sep 17 00:00:00 2001 From: andryyy Date: Sun, 8 Mar 2020 12:25:03 +0100 Subject: [PATCH] [Rspamd] Fix neural.lua --- data/Dockerfiles/rspamd/Dockerfile | 1 + data/Dockerfiles/rspamd/neural.lua | 1456 ++++++++++++++++++++++++++++ docker-compose.yml | 2 +- 3 files changed, 1458 insertions(+), 1 deletion(-) create mode 100644 data/Dockerfiles/rspamd/neural.lua diff --git a/data/Dockerfiles/rspamd/Dockerfile b/data/Dockerfiles/rspamd/Dockerfile index d4e98c2c..e7071865 100644 --- a/data/Dockerfiles/rspamd/Dockerfile +++ b/data/Dockerfiles/rspamd/Dockerfile @@ -23,6 +23,7 @@ RUN apt-get update && apt-get install -y \ && chown _rspamd:_rspamd /run/rspamd COPY settings.conf /etc/rspamd/settings.conf +COPY neural.lua /usr/share/rspamd/plugins/neural.lua COPY docker-entrypoint.sh /docker-entrypoint.sh ENTRYPOINT ["/docker-entrypoint.sh"] diff --git a/data/Dockerfiles/rspamd/neural.lua b/data/Dockerfiles/rspamd/neural.lua new file mode 100644 index 00000000..1897f084 --- /dev/null +++ b/data/Dockerfiles/rspamd/neural.lua @@ -0,0 +1,1456 @@ +--[[ +Copyright (c) 2016, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_kann = require "rspamd_kann" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local fun = require "fun" +local lua_settings = require "lua_settings" +local meta_functions = require "lua_meta" +local ts = require("tableshape").types +local lua_verdict = require "lua_verdict" +local N = "neural" + +-- Module vars +local default_options = { + train = { + max_trains = 1000, + max_epoch = 1000, + max_usages = 10, + max_iterations = 25, -- Torch style + mse = 0.001, + autotrain = true, + train_prob = 1.0, + learn_threads = 1, + learning_rate = 0.01, + }, + watch_interval = 60.0, + lock_expire = 600, + learning_spawned = false, + ann_expire = 60 * 60 * 24 * 2, -- 2 days + symbol_spam = 'NEURAL_SPAM', + symbol_ham = 'NEURAL_HAM', +} + +local redis_profile_schema = ts.shape{ + digest = ts.string, + symbols = ts.array_of(ts.string), + version = ts.number, + redis_key = ts.string, + distance = ts.number:is_optional(), +} + +-- Rule structure: +-- * static config fields (see `default_options`) +-- * prefix - name or defined prefix +-- * settings - table of settings indexed by settings id, -1 is used when no settings defined + +-- Rule settings element defines elements for specific settings id: +-- * symbols - static symbols profile (defined by config or extracted from symcache) +-- * name - name of settings id +-- * digest - digest of all symbols +-- * ann - dynamic ANN configuration loaded from Redis +-- * train - train data for ANN (e.g. 
the currently trained ANN) + +-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN +-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically +-- * version - version of ANN loaded from redis +-- * redis_key - name of ANN key in Redis +-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) +-- * distance - distance between set.symbols and set.ann.symbols +-- * ann - kann object + +local settings = { + rules = {}, + prefix = 'rn', -- Neural network default prefix + max_profiles = 3, -- Maximum number of NN profiles stored +} + +local module_config = rspamd_config:get_all_opt("neural") +if not module_config then + -- Legacy + module_config = rspamd_config:get_all_opt("fann_redis") +end + + +-- Lua script that checks if we can store a new training vector +-- Uses the following keys: +-- key1 - ann key +-- key2 - spam or ham +-- key3 - maximum trains +-- key4 - sampling coin (as Redis scripts do not allow math.random calls) +-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn +local redis_lua_script_can_store_train_vec = [[ + local prefix = KEYS[1] + local locked = redis.call('HGET', prefix, 'lock') + if locked then return {tostring(-1),'locked by another process till: ' .. locked} end + local nspam = 0 + local nham = 0 + local lim = tonumber(KEYS[3]) + local coin = tonumber(KEYS[4]) + + local ret = redis.call('LLEN', prefix .. '_spam') + if ret then nspam = tonumber(ret) end + ret = redis.call('LLEN', prefix .. '_ham') + if ret then nham = tonumber(ret) end + + if KEYS[2] == 'spam' then + if nspam <= lim then + if nspam > nham then + -- Apply sampling + local skip_rate = 1.0 - nham / (nspam + 1) + if coin < skip_rate then + return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)} + end + end + return {tostring(nspam),'can learn'} + else -- Enough learns + return {tostring(-(nspam)),'too many spam samples'} + end + else + if nham <= lim then + if nham > nspam then + -- Apply sampling + local skip_rate = 1.0 - nspam / (nham + 1) + if coin < skip_rate then + return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)} + end + end + return {tostring(nham),'can learn'} + else + return {tostring(-(nham)),'too many ham samples'} + end + end + + return {tostring(-1),'bad input'} +]] +local redis_can_store_train_vec_id = nil + +-- Lua script to invalidate ANNs by rank +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - number of elements to leave +local redis_lua_script_maybe_invalidate = [[ + local card = redis.call('ZCARD', KEYS[1]) + local lim = tonumber(KEYS[2]) + if card > lim then + local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) + for _,k in ipairs(to_delete) do + local tb = cjson.decode(k) + redis.call('DEL', tb.redis_key) + -- Also train vectors + redis.call('DEL', tb.redis_key .. '_spam') + redis.call('DEL', tb.redis_key .. 
'_ham')
+    end
+    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+    return to_delete
+  else
+    return {}
+  end
+]]
+local redis_maybe_invalidate_id = nil
+
+-- Lua script to lock ANN in Redis so that only one process trains it
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+local redis_lua_script_maybe_lock = [[
+  local locked = redis.call('HGET', KEYS[1], 'lock')
+  local now = tonumber(KEYS[2])
+  if locked then
+    locked = tonumber(locked)
+    local expire = tonumber(KEYS[3])
+    if now > locked and (now - locked) < expire then
+      return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
+    end
+  end
+  redis.call('HSET', KEYS[1], 'lock', tostring(now))
+  redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+  return 1
+]]
+local redis_maybe_lock_id = nil
+
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for ANN
+-- key2 - prefix for profile
+-- key3 - compressed ANN
+-- key4 - profile as JSON
+-- key5 - expire in seconds
+-- key6 - current time
+-- key7 - old key
+local redis_lua_script_save_unlock = [[
+  local now = tonumber(KEYS[6])
+  redis.call('ZADD', KEYS[2], now, KEYS[4])
+  redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('HDEL', KEYS[1], 'lock')
+  redis.call('HDEL', KEYS[7], 'lock')
+  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+  return 1
+]]
+local redis_save_unlock_id = nil
+
+local redis_params
+
+local function load_scripts(params)
+  redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
+    params)
+  redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+    params)
+  redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+    params)
+  redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+    params)
+end
+
+local function result_to_vector(task, profile)
+  if not profile.zeros then
+    -- Fill zeros vector
+    local zeros = {}
+    for i=1,meta_functions.count_metatokens() do
+      zeros[i] = 0.0
+    end
+    for _,_ in ipairs(profile.symbols) do
+      zeros[#zeros + 1] = 0.0
+    end
+    profile.zeros = zeros
+  end
+
+  local vec = lua_util.shallowcopy(profile.zeros)
+  local mt = meta_functions.rspamd_gen_metatokens(task)
+
+  for i,v in ipairs(mt) do
+    vec[i] = v
+  end
+
+  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+  return vec
+end
+
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set, version)
+  local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+    rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+  return ann_key
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+  local sid = task:get_settings_id() or -1
+
+  local set = rule.settings[sid]
+
+  if not set then return nil end
+
+  while type(set) == 'number' do
+    -- Reference to another settings!
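+    -- (process_rules_settings() stores a plain number here when another settings
+    -- id already uses an identical symbol profile, so we follow the reference
+    -- until we reach the actual settings table)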
+ set = rule.settings[set] + end + + return set +end + +-- Generate redis prefix for specific rule and specific settings +local function redis_ann_prefix(rule, settings_name) + -- We also need to count metatokens: + local n = meta_functions.version + return string.format('%s_%s_%d_%s', + settings.prefix, rule.prefix, n, settings_name) +end + +-- Creates and stores ANN profile in Redis +local function new_ann_profile(task, rule, set, version) + local ann_key = new_ann_key(rule, set, version) + + local profile = { + symbols = set.symbols, + redis_key = ann_key, + version = version, + digest = set.digest, + distance = 0 -- Since we are using our own profile + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + local function add_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s', + rule.prefix, set.name, profile.redis_key, err) + else + rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', + rule.prefix, set.name, profile.redis_key) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} + ) + + return profile +end + + +-- ANN filter function, used to insert scores based on the existing symbols +local function ann_scores_filter(task) + + for _,rule in pairs(settings.rules) do + local sid = task:get_settings_id() or -1 + local ann + local profile + + local set = get_rule_settings(task, rule) + if set then + if set.ann then + ann = set.ann.ann + profile = set.ann + else + lua_util.debugm(N, task, 'no ann loaded for %s:%s', + rule.prefix, set.name) + end + else + lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', + rule.prefix, sid) + end + + if ann then + local vec = result_to_vector(task, profile) + + local score + local out = ann:apply1(vec) + score = out[1] + + local symscore = string.format('%.3f', score) + lua_util.debugm(N, task, '%s:%s:%s ann score: %s', + rule.prefix, set.name, set.ann.version, symscore) + + if score > 0 then + local result = score + task:insert_result(rule.symbol_spam, result, symscore) + else + local result = -(score) + task:insert_result(rule.symbol_ham, result, symscore) + end + end + end +end + +local function create_ann(n, nlayers) + -- We ignore number of layers so far when using kann + local nhidden = math.floor((n + 1) / 2) + local t = rspamd_kann.layer.input(n) + t = rspamd_kann.transform.relu(t) + t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden)); + t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse) + return rspamd_kann.new.kann(t) +end + + +local function ann_push_task_result(rule, task, verdict, score, set) + local train_opts = rule.train + local learn_spam, learn_ham + local skip_reason = 'unknown' + + if train_opts.autotrain then + if train_opts.spam_score then + learn_spam = score >= train_opts.spam_score + + if not learn_spam then + skip_reason = string.format('score < spam_score: %f < %f', + score, train_opts.spam_score) + end + else + learn_spam = verdict == 'spam' or verdict == 'junk' + + if not learn_spam then + skip_reason = string.format('verdict: %s', + verdict) + end + end + + if train_opts.ham_score then + learn_ham = score <= train_opts.ham_score + if not learn_ham then + skip_reason = string.format('score > ham_score: %f > %f', + score, train_opts.ham_score) + end + else + learn_ham = verdict == 'ham' + + if not 
learn_ham then + skip_reason = string.format('verdict: %s', + verdict) + end + end + else + -- Train by request header + local hdr = task:get_request_header('ANN-Train') + + if hdr then + if hdr:lower() == 'spam' then + learn_spam = true + elseif hdr:lower() == 'ham' then + learn_ham = true + else + skip_reason = string.format('no explicit header') + end + end + end + + + if learn_spam or learn_ham then + local learn_type + if learn_spam then learn_type = 'spam' else learn_type = 'ham' end + + local function can_train_cb(err, data) + if not err and type(data) == 'table' then + local nsamples,reason = tonumber(data[1]),data[2] + + if nsamples >= 0 then + local coin = math.random() + + if coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return + end + + local vec = result_to_vector(task, set) + + local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = set.ann.redis_key .. '_' .. learn_type + + local function learn_vec_cb(_err) + if _err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, _err) + else + lua_util.debugm(N, task, + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'LPUSH', -- command + { target_key, str } -- arguments + ) + else + -- Negative result returned + rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)", + learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples)) + end + else + if err then + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) + else + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. 
+ 'please remove this key from Redis manually if you perform upgrade from the previous version', + rule.prefix, set.name, set.ann.redis_key, type(data)) + end + end + end + + -- Check if we can learn + if set.can_store_vectors then + if not set.ann then + -- Need to create or load a profile corresponding to the current configuration + set.ann = new_ann_profile(task, rule, set, 0) + lua_util.debugm(N, task, + 'requested new profile for %s, set.ann is missing', + set.name) + end + + lua_redis.exec_redis_script(redis_can_store_train_vec_id, + {task = task, is_write = true}, + can_train_cb, + { + set.ann.redis_key, + learn_type, + tostring(train_opts.max_trains), + tostring(math.random()), + }) + else + lua_util.debugm(N, task, + 'do not push data: train condition not satisfied; reason: not checked existing ANNs') + end + else + lua_util.debugm(N, task, + 'do not push data to key %s: train condition not satisfied; reason: %s', + (set.ann or {}).redis_key, + skip_reason) + end +end + +--- Offline training logic + +-- Closure generator for unlock function +local function gen_unlock_cb(rule, set, ann_key) + return function (err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', + rule.prefix, set.name, ann_key, err) + else + lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', + rule.prefix, set.name, ann_key) + end + end +end + +-- This function is intended to extend lock for ANN during training +-- It registers periodic that increases locked key each 30 seconds unless +-- `set.learning_spawned` is set to `true` +local function register_lock_extender(rule, set, ev_base, ann_key) + rspamd_config:add_periodic(ev_base, 30.0, + function() + local function redis_lock_extend_cb(_err, _) + if _err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + ann_key, _err) + else + rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', + ann_key) + end + end + + if set.learning_spawned then + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_lock_extend_cb, --callback + 'HINCRBY', -- command + {ann_key, 'lock', '30'} + ) + else + lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") + return false -- do not plan any more updates + end + + return true + end + ) +end + +-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis +local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) + -- Check training data sanity + -- Now we need to join inputs and create the appropriate test vectors + local n = #set.symbols + + meta_functions.rspamd_count_metatokens() + + -- Now we can train ann + local train_ann = create_ann(n, 3) + + if #ham_vec + #spam_vec < rule.train.max_trains / 2 then + -- Invalidate ANN as it is definitely invalid + -- TODO: add invalidation + assert(false) + else + local inputs, outputs = {}, {} + + -- Used to show sparsed vectors in a convenient format (for debugging only) + local function debug_vec(t) + local ret = {} + for i,v in ipairs(t) do + if v ~= 0 then + ret[#ret + 1] = string.format('%d=%.2f', i, v) + end + end + + return ret + end + + -- Make training set by joining vectors + -- KANN automatically shuffles those samples + -- 1.0 is used for spam and -1.0 is used for ham + -- It implies that output layer can express that (e.g. 
tanh output) + for _,e in ipairs(spam_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = {1.0} + --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) + end + for _,e in ipairs(ham_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = {-1.0} + --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) + end + + -- Called in child process + local function train() + local log_thresh = rule.train.max_iterations / 10 + local seen_nan = false + + local function train_cb(iter, train_cost, value_cost) + if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then + if train_cost ~= train_cost and not seen_nan then + -- We have nan :( try to log lot's of stuff to dig into a problem + seen_nan = true + rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', + rule.prefix, set.name, + value_cost) + for i,e in ipairs(inputs) do + lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', + debug_vec(e), outputs[i][1]) + end + end + + rspamd_logger.infox(rspamd_config, + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", + rule.prefix, set.name, + ann_key, + iter, + train_cost, + value_cost) + end + end + + train_ann:train1(inputs, outputs, { + lr = rule.train.learning_rate, + max_epoch = rule.train.max_iterations, + cb = train_cb, + }) + + if not seen_nan then + local out = train_ann:save() + return out + else + return nil + end + end + + set.learning_spawned = true + + local function redis_save_cb(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', + rule.prefix, set.name, ann_key, err) + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', + rule.prefix, set.name, set.ann.redis_key) + end + end + + local function ann_trained(err, data) + set.learning_spawned = false + if err then + rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', + rule.prefix, set.name, err) + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + local ann_data = rspamd_util.zstd_compress(data) + if not set.ann then + set.ann = { + symbols = set.symbols, + distance = 0, + digest = set.digest, + redis_key = ann_key, + } + end + -- Deserialise ANN from the child process + ann_trained = rspamd_kann.load(data) + local version = (set.ann.version or 0) + 1 + set.ann.version = version + set.ann.ann = ann_trained + set.ann.symbols = set.symbols + set.ann.redis_key = new_ann_key(rule, set, version) + + local profile = { + symbols = set.symbols, + digest = set.digest, + redis_key = set.ann.redis_key, + version = version + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + rspamd_logger.infox(rspamd_config, + 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)', + rule.prefix, set.name, #data, set.ann.redis_key, ann_key) + + lua_redis.exec_redis_script(redis_save_unlock_id, + {ev_base = ev_base, is_write = true}, + redis_save_cb, + {profile.redis_key, + redis_ann_prefix(rule, set.name), + ann_data, + profile_serialized, + tostring(rule.ann_expire), + tostring(os.time()), + ann_key, -- 
old key to unlock... + }) + end + end + + worker:spawn_process{ + func = train, + on_complete = ann_trained, + proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name), + } + end + -- Spawn learn and register lock extension + set.learning_spawned = true + register_lock_extender(rule, set, ev_base, ann_key) +end + +-- Utility to extract and split saved training vectors to a table of tables +local function process_training_vectors(data) + return fun.totable(fun.map(function(tok) + local _,str = rspamd_util.zstd_decompress(tok) + return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';'))) + end, data)) +end + +-- This function does the following: +-- * Tries to lock ANN +-- * Loads spam and ham vectors +-- * Spawn learning process +local function do_train_ann(worker, ev_base, rule, set, ann_key) + local spam_elts = {} + local ham_elts = {} + + local function redis_ham_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock on error + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + -- Decompress and convert to numbers each training vector + ham_elts = process_training_vectors(data) + spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts) + end + end + + -- Spam vectors received + local function redis_spam_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock ANN on error + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + -- Decompress and convert to numbers each training vector + spam_elts = process_training_vectors(data) + -- Now get ham vectors... + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_ham_cb, --callback + 'LRANGE', -- command + {ann_key .. '_ham', '0', '-1'} + ) + end + end + + local function redis_lock_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', + ann_key, err) + elseif type(data) == 'number' and data == 1 then + -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_spam_cb, --callback + 'LRANGE', -- command + {ann_key .. '_spam', '0', '-1'} + ) + + rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', + rule.prefix, set.name, ann_key) + else + local lock_tm = tonumber(data[1]) + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. 
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) + end + end + + -- Check if we are already learning this network + if set.learning_spawned then + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', + ann_key) + return + end + + -- Call Redis script that tries to acquire a lock + -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when + -- ANN is locked by another host (or a process, meh) + lua_redis.exec_redis_script(redis_maybe_lock_id, + {ev_base = ev_base, is_write = true}, + redis_lock_cb, + { + ann_key, + tostring(os.time()), + tostring(rule.watch_interval * 2), + rspamd_util.get_hostname() + }) +end + +-- This function loads new ann from Redis +-- This is based on `profile` attribute. +-- ANN is loaded from `profile.redis_key` +-- Rank of `profile` key is also increased, unfortunately, it means that we need to +-- serialize profile one more time and set its rank to the current time +-- set.ann fields are set according to Redis data received +local function load_new_ann(rule, ev_base, set, profile, min_diff) + local ann_key = profile.redis_key + + local function data_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', + ann_key, err) + else + if type(data) == 'string' then + local _err,ann_data = rspamd_util.zstd_decompress(data) + local ann + + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. set.name, ann_key, _err) + return + else + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + digest = profile.digest, + version = profile.version, + symbols = profile.symbols, + distance = min_diff, + redis_key = profile.redis_key + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + set.ann.ann = ann -- To avoid serialization + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} + ) + rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #ann_data, profile.version) + else + rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) + end + end + else + lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + end + end + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HGET', -- command + {ann_key, 'ann'} -- arguments + ) +end + +-- Used to check an element in Redis serialized as JSON +-- for some specific rule + some specific setting +-- This function tries to load more fresh or more specific ANNs in lieu of +-- the existing ones. 
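+-- A stored profile is considered compatible when its symbol set differs from the
+-- current settings by less than 30% of the symbols; among compatible profiles the
+-- closest one is selected, and it only replaces the currently loaded ANN when it
+-- is either a newer version of the same profile or a closer match.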
+-- Use this function to load ANNs as `callback` parameter for `check_anns` function +local function process_existing_ann(_, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local min_diff = math.huge + local sel_elt + + for _,elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist < #my_symbols * .3 then + if dist < min_diff then + min_diff = dist + sel_elt = elt + end + end + end + end + + if sel_elt then + -- We can load element from ANN + if set.ann then + -- We have an existing ANN, probably the same... + if set.ann.digest == sel_elt.digest then + -- Same ANN, check version + if set.ann.version < sel_elt.version then + -- Load new ann + rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + end + else + -- We have some different ANN, so we need to compare distance + if set.ann.distance > min_diff then + -- Load more specific ANN + rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + end + end + else + -- We have no ANN, load new one + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + end + end +end + + +-- This function checks all profiles and selects if we can train our +-- ANN. By our we mean that it has exactly the same symbols in profile. 
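+-- Training is only initiated once both the spam and the ham lists for the selected
+-- profile contain at least `max_trains` vectors; do_train_ann() then takes the
+-- lock and spawns the learning process.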
+-- Use this function to train ANN as `callback` parameter for `check_anns` function +local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local sel_elt + + for _,elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist == 0 then + sel_elt = elt + break + end + end + end + + if sel_elt then + -- We have our ANN and that's train vectors, check if we can learn + local ann_key = sel_elt.redis_key + + lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", + ann_key) + + -- Create continuation closure + local redis_len_cb_gen = function(cont_cb, what, is_final) + return function(err, data) + if err then + rspamd_logger.errx(rspamd_config, + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + if tonumber(data) and tonumber(data) >= rule.train.max_trains then + if is_final then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, tonumber(data), rule.train.max_trains, what) + else + rspamd_logger.debugm(N, rspamd_config, + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, tonumber(data), rule.train.max_trains) + end + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, tonumber(data), rule.train.max_trains) + end + end + end + + end + + local function initiate_train() + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s required learn vectors', + ann_key, rule.train.max_trains) + do_train_ann(worker, ev_base, rule, set, ann_key) + end + + -- Spam vector is OK, check ham vector length + local function check_ham_len() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train, 'ham', true), --callback + 'LLEN', -- command + {ann_key .. '_ham'} + ) + end + + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(check_ham_len, 'spam', false), --callback + 'LLEN', -- command + {ann_key .. 
'_spam'} + ) + end +end + +-- Used to deserialise ANN element from a list +local function load_ann_profile(element) + local ucl = require "ucl" + + local parser = ucl.parser() + local res,ucl_err = parser:parse_string(element) + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', + ucl_err) + return nil + else + local profile = parser:get_object() + local checked,schema_err = redis_profile_schema:transform(profile) + if not checked then + rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err) + + return nil + end + return checked + end +end + +-- Function to check or load ANNs from Redis +local function check_anns(worker, cfg, ev_base, rule, process_callback, what) + for _,set in pairs(rule.settings) do + local function members_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', + err) + set.can_store_vectors = true + elseif type(data) == 'table' then + lua_util.debugm(N, cfg, '%s: process element %s:%s', + what, rule.prefix, set.name) + process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) + set.can_store_vectors = true + end + end + + if type(set) == 'table' then + -- Extract all profiles for some specific settings id + -- Get the last `max_profiles` recently used + -- Select the most appropriate to our profile but it should not differ by more + -- than 30% of symbols + lua_redis.redis_make_request_taskless(ev_base, + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', -- command + {set.prefix, '0', tostring(settings.max_profiles)} -- arguments + ) + end + end -- Cycle over all settings + + return rule.watch_interval +end + +-- Function to clean up old ANNs +local function cleanup_anns(rule, cfg, ev_base) + for _,set in pairs(rule.settings) do + local function invalidate_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', + err) + elseif type(data) == 'table' then + for _,expired in ipairs(data) do + local profile = load_ann_profile(expired) + rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', + rule.prefix .. ':' .. set.name, + profile.redis_key, + profile.version) + end + end + end + + if type(set) == 'table' then + lua_redis.exec_redis_script(redis_maybe_invalidate_id, + {ev_base = ev_base, is_write = true}, + invalidate_cb, + {set.prefix, tostring(settings.max_profiles)}) + end + end +end + +local function ann_push_vector(task) + if task:has_flag('skip') then + lua_util.debugm(N, task, 'do not push data for skipped task') + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + lua_util.debugm(N, task, 'do not push data for manual scan') + return + end + + local verdict,score = lua_verdict.get_specific_verdict(N, task) + + if verdict == 'passthrough' then + lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', + verdict, score) + + return + end + + if score ~= score then + lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', + verdict) + + return + end + + for _,rule in pairs(settings.rules) do + local set = get_rule_settings(task, rule) + + if set then + ann_push_task_result(rule, task, verdict, score, set) + else + lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) + end + + end +end + + +-- This function is used to adjust profiles and allowed setting ids for each rule +-- It must be called when all settings are already registered (e.g. 
at post-init for config)
+local function process_rules_settings()
+  local function process_settings_elt(rule, selt)
+    local profile = rule.profile[selt.name]
+    if profile then
+      -- Use static user defined profile
+      -- Ensure that we have an array...
+      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+        rule.prefix, selt.name, profile)
+      if not profile[1] then profile = lua_util.keys(profile) end
+      selt.symbols = profile
+    else
+      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+        rule.prefix, selt.name)
+    end
+
+    local function filter_symbols_predicate(sname)
+      local fl = rspamd_config:get_symbol_flags(sname)
+      if fl then
+        fl = lua_util.list_to_hash(fl)
+
+        return not (fl.nostat or fl.idempotent or fl.skip)
+      end
+
+      return false
+    end
+
+    -- Generic stuff
+    table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
+
+    selt.digest = lua_util.table_digest(selt.symbols)
+    selt.prefix = redis_ann_prefix(rule, selt.name)
+
+    lua_redis.register_prefix(selt.prefix, N,
+      string.format('NN prefix for rule "%s"; settings id "%s"',
+        rule.prefix, selt.name), {
+        persistent = true,
+        type = 'zlist',
+      })
+    -- Versions
+    lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+      string.format('NN storage for rule "%s"; settings id "%s"',
+        rule.prefix, selt.name), {
+        persistent = true,
+        type = 'hash',
+      })
+    lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
+      string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+        rule.prefix, selt.name), {
+        persistent = true,
+        type = 'list',
+      })
+    lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
+      string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
+        rule.prefix, selt.name), {
+        persistent = true,
+        type = 'list',
+      })
+  end
+
+  for k,rule in pairs(settings.rules) do
+    if not rule.allowed_settings then
+      rule.allowed_settings = {}
+    elseif rule.allowed_settings == 'all' then
+      -- Extract all settings ids
+      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+    end
+
+    -- Convert to a map -> true
+    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+    -- Check if we can work without settings
+    if k == 'default' or type(rule.default) ~= 'boolean' then
+      rule.default = true
+    end
+
+    rule.settings = {}
+
+    if rule.default then
+      local default_settings = {
+        symbols = lua_settings.default_symbols(),
+        name = 'default'
+      }
+
+      process_settings_elt(rule, default_settings)
+      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+    end
+
+    -- Now, for each allowed settings, we store sorted symbols + digest
+    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+    for s,_ in pairs(rule.allowed_settings) do
+      -- Here, we have a name, set of symbols and
+      local settings_id = s
+      if type(settings_id) ~= 'number' then
+        settings_id = lua_settings.numeric_settings_id(s)
+      end
+      local selt = lua_settings.settings_by_id(settings_id)
+
+      local nelt = {
+        symbols = selt.symbols, -- Already sorted
+        name = selt.name
+      }
+
+      process_settings_elt(rule, nelt)
+      for id,ex in pairs(rule.settings) do
+        if type(ex) == 'table' then
+          if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+            -- Equal symbols, add reference
+            lua_util.debugm(N, rspamd_config,
+              'added reference from settings id %s to %s; same symbols',
+              nelt.name, ex.name)
+            rule.settings[settings_id] = id
+            nelt = nil
+          end
+        end
+      end
+
+      if nelt then
+        rule.settings[settings_id] = 
nelt + lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', + nelt.name, settings_id, rule.prefix) + end + end + end +end + +redis_params = lua_redis.parse_redis_server('neural') + +if not redis_params then + redis_params = lua_redis.parse_redis_server('fann_redis') +end + +-- Initialization part +if not (module_config and type(module_config) == 'table') or not redis_params then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + lua_util.disable_module(N, "redis") + return +end + +local rules = module_config['rules'] + +if not rules then + -- Use legacy configuration + rules = {} + rules['default'] = module_config +end + +local id = rspamd_config:register_symbol({ + name = 'NEURAL_CHECK', + type = 'postfilter,nostat', + priority = 6, + callback = ann_scores_filter +}) + +settings = lua_util.override_defaults(settings, module_config) +settings.rules = {} -- Reset unless validated further in the cycle + +-- Check all rules +for k,r in pairs(rules) do + local rule_elt = lua_util.override_defaults(default_options, r) + rule_elt['redis'] = redis_params + rule_elt['anns'] = {} -- Store ANNs here + + if not rule_elt.prefix then + rule_elt.prefix = k + end + if not rule_elt.name then + rule_elt.name = k + end + if rule_elt.train.max_train then + rule_elt.train.max_trains = rule_elt.train.max_train + end + + if not rule_elt.profile then rule_elt.profile = {} end + + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) + settings.rules[k] = rule_elt + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_spam, + score = 0.0, + description = 'Neural network SPAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_spam, + type = 'virtual,nostat', + parent = id + }) + + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_ham, + score = -0.0, + description = 'Neural network HAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_ham, + type = 'virtual,nostat', + parent = id + }) +end + +rspamd_config:register_symbol({ + name = 'NEURAL_LEARN', + type = 'idempotent,nostat,explicit_disable', + priority = 5, + callback = ann_push_vector +}) + +-- Add training scripts +for _,rule in pairs(settings.rules) do + load_scripts(rule.redis) + -- We also need to deal with settings + rspamd_config:add_post_init(process_rules_settings) + -- This function will check ANNs in Redis when a worker is loaded + rspamd_config:add_on_load(function(cfg, ev_base, worker) + if worker:is_scanner() then + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) + end + + if worker:is_primary_controller() then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + -- Clean old ANNs + cleanup_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') + end) + end + end) +end diff --git a/docker-compose.yml b/docker-compose.yml index bbe7347d..0024c0b9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -69,7 +69,7 @@ services: - clamd rspamd-mailcow: - image: mailcow/rspamd:1.62 + image: mailcow/rspamd:1.64 stop_grace_period: 30s depends_on: - nginx-mailcow