{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}

module Simplex.Chat.Store.SQLite.Migrations.M20251117_member_relations_vector where

import qualified Data.ByteString as B
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite3 (funcArgBlob, funcArgInt64, funcArgText, funcResultBlob)
import Database.SQLite3.Bindings
import Foreign.C.Types
import Foreign.Ptr
import Simplex.Chat.Types.MemberRelations (IntroductionDirection (..), MemberRelation (..), fromIntroDirInt, fromRelationInt, setNewRelation, setNewRelations)
import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, SQLiteFuncFinal, mkSQLiteAggFinal, mkSQLiteAggStep, mkSQLiteFunc)

-- This module defines the custom aggregate SQL function migrate_relations_vector(idx, direction, intro_status).
-- The aggregate is passed via DBOpts and registered when the database is opened.
-- It is used in the live migration and in the stage 2 migration (M20251128_migrate_member_relations).
--
-- Vector byte encoding (bit-field widths): 4 reserved | 1 direction | 3 status
-- Direction: 0 = IDSubjectIntroduced, 1 = IDReferencedIntroduced
-- Status values: 0 = MRNew, 1 = MRIntroduced, 2 = MRSubjectConnected, 3 = MRReferencedConnected, 4 = MRConnected
--
-- The aggregate transforms intro_status into relation status:
-- - intro_status 'new'/'sent'/'rcv'/'fwd': MRIntroduced (1)
-- - intro_status 're-con': if direction=0 then MRSubjectConnected (2), else MRReferencedConnected (3)
-- - intro_status 'to-con': if direction=0 then MRReferencedConnected (3), else MRSubjectConnected (2)
-- - intro_status 'con': MRConnected (4)
--
-- The final function builds the vector using setNewRelations.

-- Export the aggregate's step callback under a C symbol name.
foreign export ccall "simplex_member_relations_step" sqliteMemberRelationsStep :: SQLiteFunc

-- Import the address of that exported symbol back as a FunPtr, so the callback
-- can be handed to C code (presumably when registering the aggregate with SQLite).
foreign import ccall "&simplex_member_relations_step" sqliteMemberRelationsStepPtr :: FunPtr SQLiteFunc

-- Export the aggregate's final callback under a C symbol name.
foreign export ccall "simplex_member_relations_final" sqliteMemberRelationsFinal :: SQLiteFuncFinal

-- Address of the exported final callback, as a FunPtr usable from C.
foreign import ccall "&simplex_member_relations_final" sqliteMemberRelationsFinalPtr :: FunPtr SQLiteFuncFinal

-- | Step function for the @migrate_relations_vector@ aggregate.
--
-- SQL arguments, by position:
--
--   * 0: @idx@ (INTEGER), the member index;
--   * 1: @direction@ (INTEGER), converted with 'fromIntroDirInt';
--   * 2: @intro_status@ (TEXT), mapped to a 'MemberRelation' by @introStatusToRelation@.
--
-- Each call prepends @(idx, (direction, relation))@ to the accumulator, which
-- starts as @[]@. The accumulated list is therefore in reverse row order.
sqliteMemberRelationsStep :: SQLiteFunc
sqliteMemberRelationsStep = mkSQLiteAggStep [] $ \_ args acc -> do
  idx <- funcArgInt64 args 0
  direction <- fromIntroDirInt . fromIntegral <$> funcArgInt64 args 1
  introStatus <- funcArgText args 2
  let relation = introStatusToRelation direction introStatus
  pure $ (idx, (direction, relation)) : acc
  where
    -- Maps an intro_status string to a relation status.
    -- The meaning of "re-con" / "to-con" depends on the introduction direction.
    -- No signature here: matching string literals requires (Eq a, IsString a),
    -- which GHC infers.
    introStatusToRelation dir status = case status of
      "re-con" -> if dir == IDSubjectIntroduced then MRSubjectConnected else MRReferencedConnected
      "to-con" -> if dir == IDSubjectIntroduced then MRReferencedConnected else MRSubjectConnected
      "con" -> MRConnected
      _ -> MRIntroduced -- 'new', 'sent', 'rcv', 'fwd'

-- | Final function for the @migrate_relations_vector@ aggregate.
--
-- Builds the relations vector by applying 'setNewRelations' to the accumulated
-- @(idx, (direction, relation))@ tuples, starting from an empty vector.
-- The result is returned to SQLite as a BLOB.
sqliteMemberRelationsFinal :: SQLiteFuncFinal
sqliteMemberRelationsFinal = mkSQLiteAggFinal [] $ \cxt acc ->
  funcResultBlob cxt $ setNewRelations acc B.empty

-- Scalar (non-aggregate) SQL function set_member_vector_new_relation(vector, idx, direction, status).
-- Sets a new relation in the vector and returns the updated vector.

-- Export the scalar callback under a C symbol name.
foreign export ccall "simplex_set_member_vector_new_relation" sqliteSetMemberVectorNewRelation :: SQLiteFunc

-- Address of the exported callback, as a FunPtr usable from C (presumably for function registration).
foreign import ccall "&simplex_set_member_vector_new_relation" sqliteSetMemberVectorNewRelationPtr :: FunPtr SQLiteFunc

-- | Implementation of @set_member_vector_new_relation(vector, idx, direction, status)@.
--
-- SQL arguments, by position:
--
--   * 0: @vector@ (BLOB);
--   * 1: @idx@ (INTEGER);
--   * 2: @direction@ (INTEGER), converted with 'fromIntroDirInt';
--   * 3: @status@ (INTEGER), converted with 'fromRelationInt'.
--
-- Returns, as a BLOB, the vector produced by 'setNewRelation'.
-- The integer arguments are narrowed with 'fromIntegral' before conversion.
sqliteSetMemberVectorNewRelation :: SQLiteFunc
sqliteSetMemberVectorNewRelation = mkSQLiteFunc $ \cxt args -> do
  v <- funcArgBlob args 0
  idx <- funcArgInt64 args 1
  direction <- fromIntroDirInt . fromIntegral <$> funcArgInt64 args 2
  status <- fromRelationInt . fromIntegral <$> funcArgInt64 args 3
  funcResultBlob cxt $ setNewRelation idx direction status v

-- | Up migration.
--
-- * Adds @group_members.index_in_group@, @groups.member_index@ and
--   @group_members.member_relations_vector@.
-- * Numbers existing members within each group, 0-based in @group_member_id@
--   order, using a temporary table. Temporary indexes support the window
--   function and the correlated UPDATE; they are dropped afterwards.
-- * Enforces uniqueness of @(group_id, index_in_group)@.
-- * Sets each group's @member_index@ to the next free index (0 for empty groups).
-- * Initialises @member_relations_vector@ to an empty blob for all members of
--   groups where the user's own member either has a role other than admin/owner
--   or has status removed/left/deleted.
--   The role is compared as BLOB via CAST.
m20251117_member_relations_vector :: Query
m20251117_member_relations_vector =
  [sql|
ALTER TABLE group_members ADD COLUMN index_in_group INTEGER NOT NULL DEFAULT 0;

ALTER TABLE groups ADD COLUMN member_index INTEGER NOT NULL DEFAULT 0;

ALTER TABLE group_members ADD COLUMN member_relations_vector BLOB;

CREATE INDEX tmp_idx_group_members_group_id_group_member_id ON group_members(group_id, group_member_id);

CREATE TEMPORARY TABLE tmp_members_indexed AS
SELECT
  group_member_id,
  ROW_NUMBER() OVER (
    PARTITION BY group_id
    ORDER BY group_member_id ASC
  ) - 1 AS idx_in_group
FROM group_members;

CREATE INDEX tmp_idx_members_indexed ON tmp_members_indexed(group_member_id);

UPDATE group_members AS gm
SET index_in_group = (
  SELECT idx_in_group
  FROM tmp_members_indexed
  WHERE tmp_members_indexed.group_member_id = gm.group_member_id
);

DROP INDEX tmp_idx_group_members_group_id_group_member_id;
DROP INDEX tmp_idx_members_indexed;
DROP TABLE tmp_members_indexed;

CREATE UNIQUE INDEX idx_group_members_group_id_index_in_group ON group_members(group_id, index_in_group);

UPDATE groups AS g
SET member_index = COALESCE((
  SELECT MAX(index_in_group) + 1
  FROM group_members
  WHERE group_members.group_id = g.group_id
), 0);

UPDATE group_members
SET member_relations_vector = x''
WHERE group_id IN (
  SELECT mu.group_id
  FROM group_members mu
  WHERE mu.member_category = 'user'
    AND (
      mu.member_role NOT IN (CAST('admin' AS BLOB), CAST('owner' AS BLOB))
      OR mu.member_status IN ('removed', 'left', 'deleted')
    )
);
|]

-- | Down migration: reverts 'm20251117_member_relations_vector'.
--
-- The unique index on @(group_id, index_in_group)@ is dropped first because
-- SQLite refuses @DROP COLUMN@ on an indexed column. The three added columns
-- are then removed.
down_m20251117_member_relations_vector :: Query
down_m20251117_member_relations_vector =
  [sql|
DROP INDEX idx_group_members_group_id_index_in_group;

ALTER TABLE group_members DROP COLUMN index_in_group;

ALTER TABLE groups DROP COLUMN member_index;

ALTER TABLE group_members DROP COLUMN member_relations_vector;
|]