536 lines
15 KiB
Go
536 lines
15 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/bwmarrin/discordgo"
|
|
"github.com/openclaw/crawlkit/embed"
|
|
"github.com/openclaw/discrawl/internal/config"
|
|
"github.com/openclaw/discrawl/internal/discord"
|
|
"github.com/openclaw/discrawl/internal/share"
|
|
"github.com/openclaw/discrawl/internal/store"
|
|
"github.com/openclaw/discrawl/internal/syncer"
|
|
)
|
|
|
|
type cliError struct {
|
|
code int
|
|
err error
|
|
}
|
|
|
|
func (e *cliError) Error() string {
|
|
return e.err.Error()
|
|
}
|
|
|
|
func (e *cliError) Unwrap() error {
|
|
return e.err
|
|
}
|
|
|
|
func ExitCode(err error) int {
|
|
if err == nil {
|
|
return 0
|
|
}
|
|
if errors.Is(err, context.Canceled) {
|
|
return 1
|
|
}
|
|
var codeErr *cliError
|
|
if errors.As(err, &codeErr) {
|
|
return codeErr.code
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error {
|
|
if len(args) == 0 || args[0] == "help" || args[0] == "--help" || args[0] == "-h" {
|
|
printUsage(stdout)
|
|
return nil
|
|
}
|
|
global := flag.NewFlagSet("discrawl", flag.ContinueOnError)
|
|
global.SetOutput(io.Discard)
|
|
configPath := global.String("config", "", "")
|
|
jsonOut := global.Bool("json", false, "")
|
|
plainOut := global.Bool("plain", false, "")
|
|
quiet := global.Bool("quiet", false, "")
|
|
global.BoolVar(quiet, "q", false, "")
|
|
verbose := global.Bool("verbose", false, "")
|
|
global.BoolVar(verbose, "v", false, "")
|
|
versionFlag := global.Bool("version", false, "")
|
|
global.Bool("no-color", false, "")
|
|
if err := global.Parse(args); err != nil {
|
|
return usageErr(err)
|
|
}
|
|
if *versionFlag {
|
|
_, _ = io.WriteString(stdout, version+"\n")
|
|
return nil
|
|
}
|
|
rest := global.Args()
|
|
if len(rest) == 0 || rest[0] == "help" || rest[0] == "--help" || rest[0] == "-h" {
|
|
printUsage(stdout)
|
|
return nil
|
|
}
|
|
if rest[0] == "version" {
|
|
_, _ = io.WriteString(stdout, version+"\n")
|
|
return nil
|
|
}
|
|
level := slog.LevelInfo
|
|
if *quiet {
|
|
level = slog.LevelError
|
|
}
|
|
if *verbose {
|
|
level = slog.LevelDebug
|
|
}
|
|
runtime := &runtime{
|
|
ctx: ctx,
|
|
configPath: config.ResolvePath(*configPath),
|
|
stdout: stdout,
|
|
stderr: stderr,
|
|
json: *jsonOut,
|
|
plain: *plainOut,
|
|
logger: slog.New(slog.NewTextHandler(stderr, &slog.HandlerOptions{Level: level})),
|
|
}
|
|
return runtime.dispatch(rest)
|
|
}
|
|
|
|
type runtime struct {
|
|
ctx context.Context
|
|
configPath string
|
|
cfg config.Config
|
|
stdout io.Writer
|
|
stderr io.Writer
|
|
json bool
|
|
plain bool
|
|
logger *slog.Logger
|
|
store *store.Store
|
|
client discordClient
|
|
syncer syncService
|
|
dbLockHeld bool
|
|
lockStarted time.Time
|
|
openStore func(context.Context, string) (*store.Store, error)
|
|
newDiscord func(config.Config) (discordClient, error)
|
|
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
|
|
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
|
|
now func() time.Time
|
|
}
|
|
|
|
func crawlkitEmbeddingConfig(cfg config.EmbeddingsConfig) embed.Config {
|
|
return embed.Config{
|
|
Provider: cfg.Provider,
|
|
Model: cfg.Model,
|
|
BaseURL: cfg.BaseURL,
|
|
APIKeyEnv: cfg.APIKeyEnv,
|
|
RequestTimeout: cfg.RequestTimeout,
|
|
MaxInputChars: cfg.MaxInputChars,
|
|
}
|
|
}
|
|
|
|
type discordClient interface {
|
|
syncer.Client
|
|
Close() error
|
|
Self(context.Context) (*discordgo.User, error)
|
|
Guilds(context.Context) ([]*discordgo.UserGuild, error)
|
|
}
|
|
|
|
type syncService interface {
|
|
DiscoverGuilds(context.Context) ([]*discordgo.UserGuild, error)
|
|
Sync(context.Context, syncer.SyncOptions) (syncer.SyncStats, error)
|
|
RunTail(context.Context, []string, time.Duration) error
|
|
}
|
|
|
|
type attachmentTextConfigurer interface {
|
|
SetAttachmentTextEnabled(bool)
|
|
}
|
|
|
|
func (r *runtime) dispatch(rest []string) error {
|
|
switch rest[0] {
|
|
case "metadata":
|
|
return r.runMetadata(rest[1:])
|
|
case "init":
|
|
return r.runInit(rest[1:])
|
|
case "sync":
|
|
updateMode, err := syncShareUpdateMode(rest[1:])
|
|
if err != nil {
|
|
return usageErr(err)
|
|
}
|
|
return r.withLocalStoreUpdateLocked(updateMode, true, func() error { return r.runSync(rest[1:]) })
|
|
case "tail":
|
|
return r.withServicesLocked(true, func() error { return r.runTail(rest[1:]) })
|
|
case "wiretap":
|
|
return r.withLocalStoreLocked(false, func() error { return r.runWiretap(rest[1:]) })
|
|
case "tap", "cache-import":
|
|
return r.withLocalStoreLocked(false, func() error { return r.runWiretap(rest[1:]) })
|
|
case "search":
|
|
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
|
return r.withLocalStoreRead(autoShareUpdate, func() error { return r.runSearch(rest[1:]) })
|
|
case "tui":
|
|
if hasHelpArg(rest[1:]) {
|
|
return r.runTUI(rest[1:])
|
|
}
|
|
return r.withLocalStoreReadOnly(func() error { return r.runTUI(rest[1:]) })
|
|
case "messages":
|
|
if hasBoolFlag(rest[1:], "--sync") && !hasBoolFlag(rest[1:], "--dm") {
|
|
return r.withServicesAutoLocked(true, true, true, func() error { return r.runMessages(rest[1:]) })
|
|
}
|
|
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
|
return r.withLocalStoreRead(autoShareUpdate, func() error { return r.runMessages(rest[1:]) })
|
|
case "digest":
|
|
return r.withLocalStoreRead(true, func() error { return r.runDigest(rest[1:]) })
|
|
case "analytics":
|
|
return r.runAnalytics(rest[1:])
|
|
case "dms":
|
|
return r.withLocalStoreRead(false, func() error { return r.runDirectMessages(rest[1:]) })
|
|
case "mentions":
|
|
return r.withLocalStoreRead(true, func() error { return r.runMentions(rest[1:]) })
|
|
case "embed":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runEmbed(rest[1:]) })
|
|
case "sql":
|
|
if boolFlagEnabled(rest[1:], "--unsafe") {
|
|
return r.withLocalStoreLocked(true, func() error { return r.runSQL(rest[1:]) })
|
|
}
|
|
return r.withLocalStoreRead(true, func() error { return r.runSQL(rest[1:]) })
|
|
case "members":
|
|
return r.withLocalStoreRead(true, func() error { return r.runMembers(rest[1:]) })
|
|
case "channels":
|
|
return r.withLocalStoreRead(true, func() error { return r.runChannels(rest[1:]) })
|
|
case "status":
|
|
return r.withLocalStoreReadOnly(func() error { return r.runStatus(rest[1:]) })
|
|
case "report":
|
|
return r.withLocalStoreRead(true, func() error { return r.runReport(rest[1:]) })
|
|
case "publish":
|
|
return r.withServicesAutoLocked(false, false, true, func() error { return r.runPublish(rest[1:]) })
|
|
case "subscribe":
|
|
return r.runSubscribe(rest[1:])
|
|
case "update":
|
|
return r.withServicesAutoLocked(false, false, true, func() error { return r.runUpdate(rest[1:]) })
|
|
case "doctor":
|
|
return r.runDoctor(rest[1:])
|
|
default:
|
|
return usageErr(fmt.Errorf("unknown command %q", rest[0]))
|
|
}
|
|
}
|
|
|
|
func (r *runtime) withServices(withDiscord bool, fn func() error) error {
|
|
return r.withServicesAuto(withDiscord, !withDiscord, fn)
|
|
}
|
|
|
|
func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error {
|
|
return r.withServicesAutoLocked(withDiscord, !withDiscord, true, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error {
|
|
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), true, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreRead(autoShareUpdate bool, fn func() error) error {
|
|
return r.withLocalStoreReadUpdate(boolShareUpdateMode(autoShareUpdate), fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreReadUpdate(updateMode shareUpdateMode, fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return configErr(err)
|
|
}
|
|
cfg = config.Default()
|
|
if err := cfg.Normalize(); err != nil {
|
|
return configErr(err)
|
|
}
|
|
}
|
|
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
|
return configErr(err)
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
if r.shouldAutoUpdateShare(updateMode) {
|
|
if err := r.autoUpdateShareIfLockAvailable(dbPath, updateMode); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return r.openLocalStoreReadOnly(dbPath, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return configErr(err)
|
|
}
|
|
cfg = config.Default()
|
|
if err := cfg.Normalize(); err != nil {
|
|
return configErr(err)
|
|
}
|
|
}
|
|
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
|
return configErr(err)
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
if lockDB {
|
|
return r.withSyncLock(func() error {
|
|
return r.openLocalStore(dbPath, updateMode, fn)
|
|
})
|
|
}
|
|
return r.openLocalStore(dbPath, updateMode, fn)
|
|
}
|
|
|
|
func (r *runtime) shouldAutoUpdateShare(mode shareUpdateMode) bool {
|
|
return os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" &&
|
|
r.cfg.ShareEnabled() &&
|
|
(mode == shareUpdateForce || mode == shareUpdateAuto || (mode == shareUpdateConfigured && r.cfg.Share.AutoUpdate))
|
|
}
|
|
|
|
func (r *runtime) autoUpdateShareIfLockAvailable(dbPath string, updateMode shareUpdateMode) error {
|
|
locked, err := r.tryWithSyncLock(func() error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var openErr error
|
|
r.store, openErr = storeFactory(r.ctx, dbPath)
|
|
if openErr != nil {
|
|
return dbErr(openErr)
|
|
}
|
|
defer func() {
|
|
_ = r.store.Close()
|
|
r.store = nil
|
|
}()
|
|
return r.autoUpdateShare(updateMode)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !locked {
|
|
r.logger.Info("share update skipped; sync lock is held")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var err error
|
|
r.store, err = storeFactory(r.ctx, dbPath)
|
|
if err != nil {
|
|
return dbErr(err)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
|
if err := r.autoUpdateShare(updateMode); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreReadOnly(fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return configErr(err)
|
|
}
|
|
cfg = config.Default()
|
|
if err := cfg.Normalize(); err != nil {
|
|
return configErr(err)
|
|
}
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
var openErr error
|
|
r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
|
|
if openErr != nil {
|
|
if errors.Is(openErr, os.ErrNotExist) {
|
|
r.store = nil
|
|
return fn()
|
|
}
|
|
return dbErr(openErr)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) openLocalStoreReadOnly(dbPath string, fn func() error) error {
|
|
var openErr error
|
|
r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
|
|
if openErr != nil {
|
|
if errors.Is(openErr, os.ErrNotExist) {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
r.store, openErr = storeFactory(r.ctx, dbPath)
|
|
if openErr == nil {
|
|
defer func() { _ = r.store.Close() }()
|
|
return fn()
|
|
}
|
|
}
|
|
if errors.Is(openErr, store.ErrSchemaVersionMismatch) {
|
|
if err := r.withSyncLock(func() error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var migrateErr error
|
|
r.store, migrateErr = storeFactory(r.ctx, dbPath)
|
|
if migrateErr != nil {
|
|
return dbErr(migrateErr)
|
|
}
|
|
closeErr := r.store.Close()
|
|
r.store = nil
|
|
return closeErr
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
|
|
if openErr == nil {
|
|
defer func() { _ = r.store.Close() }()
|
|
return fn()
|
|
}
|
|
}
|
|
return dbErr(openErr)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func() error) error {
|
|
return r.withServicesAutoLocked(withDiscord, autoShareUpdate, false, fn)
|
|
}
|
|
|
|
func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error {
|
|
return r.withServicesUpdateLocked(withDiscord, boolShareUpdateMode(autoShareUpdate), lockDB, fn)
|
|
}
|
|
|
|
func (r *runtime) withServicesUpdateLocked(withDiscord bool, updateMode shareUpdateMode, lockDB bool, fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
|
return configErr(err)
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
if lockDB {
|
|
return r.withSyncLock(func() error {
|
|
return r.openServices(dbPath, withDiscord, updateMode, fn)
|
|
})
|
|
}
|
|
return r.openServices(dbPath, withDiscord, updateMode, fn)
|
|
}
|
|
|
|
func (r *runtime) openServices(dbPath string, withDiscord bool, updateMode shareUpdateMode, fn func() error) error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var err error
|
|
r.store, err = storeFactory(r.ctx, dbPath)
|
|
if err != nil {
|
|
return dbErr(err)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
|
if err := r.autoUpdateShare(updateMode); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if withDiscord {
|
|
if err := r.ensureDiscordServices(); err != nil {
|
|
return err
|
|
}
|
|
if r.client != nil {
|
|
defer func() { _ = r.client.Close() }()
|
|
}
|
|
}
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) ensureDiscordServices() error {
|
|
discordFactory := r.newDiscord
|
|
if discordFactory == nil {
|
|
discordFactory = func(cfg config.Config) (discordClient, error) {
|
|
token, err := config.ResolveDiscordToken(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return discord.New(token.Token)
|
|
}
|
|
}
|
|
client, err := discordFactory(r.cfg)
|
|
if err != nil {
|
|
return authErr(err)
|
|
}
|
|
r.client = client
|
|
syncerFactory := r.newSyncer
|
|
if syncerFactory == nil {
|
|
syncerFactory = func(client syncer.Client, s *store.Store, logger *slog.Logger) syncService {
|
|
return syncer.New(client, s, logger)
|
|
}
|
|
}
|
|
r.syncer = syncerFactory(r.client, r.store, r.logger)
|
|
if configurable, ok := r.syncer.(attachmentTextConfigurer); ok {
|
|
configurable.SetAttachmentTextEnabled(r.cfg.AttachmentTextEnabled())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *runtime) autoUpdateShare(mode shareUpdateMode) error {
|
|
if !r.cfg.ShareEnabled() || (mode == shareUpdateConfigured && !r.cfg.Share.AutoUpdate) {
|
|
return nil
|
|
}
|
|
staleAfter, err := time.ParseDuration(r.cfg.Share.StaleAfter)
|
|
if err != nil {
|
|
return configErr(fmt.Errorf("invalid share.stale_after: %w", err))
|
|
}
|
|
if mode != shareUpdateForce && !share.NeedsImport(r.ctx, r.store, staleAfter) {
|
|
return nil
|
|
}
|
|
opts, err := r.shareOptions()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.setSyncLockPhase("share pull")
|
|
r.logger.Info("share update pulling", "repo_path", opts.RepoPath, "remote", opts.Remote)
|
|
if err := share.Pull(r.ctx, opts); err != nil {
|
|
return err
|
|
}
|
|
r.setSyncLockPhase("share import")
|
|
_, _, err = share.ImportIfChanged(r.ctx, r.store, opts)
|
|
if errors.Is(err, share.ErrNoManifest) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (r *runtime) shareOptions() (share.Options, error) {
|
|
repoPath, err := config.ExpandPath(r.cfg.Share.RepoPath)
|
|
if err != nil {
|
|
return share.Options{}, configErr(err)
|
|
}
|
|
return share.Options{
|
|
RepoPath: repoPath,
|
|
Remote: r.cfg.Share.Remote,
|
|
Branch: r.cfg.Share.Branch,
|
|
Progress: r.shareProgress,
|
|
}, nil
|
|
}
|