discrawl/internal/cli/cli.go
2026-05-08 10:04:39 +01:00

536 lines
15 KiB
Go

package cli
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"os"
"time"
"github.com/bwmarrin/discordgo"
"github.com/openclaw/crawlkit/embed"
"github.com/openclaw/discrawl/internal/config"
"github.com/openclaw/discrawl/internal/discord"
"github.com/openclaw/discrawl/internal/share"
"github.com/openclaw/discrawl/internal/store"
"github.com/openclaw/discrawl/internal/syncer"
)
// cliError pairs an underlying error with the process exit code that
// ExitCode should report for it.
type cliError struct {
	code int   // exit status returned by ExitCode
	err  error // wrapped cause; supplies the message
}

// Error reports the wrapped error's message unchanged.
func (ce *cliError) Error() string {
	msg := ce.err.Error()
	return msg
}

// Unwrap exposes the wrapped error so errors.Is and errors.As can look
// through a cliError.
func (ce *cliError) Unwrap() error {
	return ce.err
}
// ExitCode maps an error returned by Run to a process exit status.
//
// A nil error yields 0. A cancelled context yields 1 even when the error
// chain also carries a cliError. Otherwise the code of the first cliError in
// the chain is used, and any other error yields 1.
func ExitCode(err error) int {
	var coded *cliError
	switch {
	case err == nil:
		return 0
	case errors.Is(err, context.Canceled):
		// Cancellation is checked before the coded error on purpose.
		return 1
	case errors.As(err, &coded):
		return coded.code
	default:
		return 1
	}
}
// Run is the CLI entry point. It parses the global flags at the front of
// args, answers help and version requests itself, and hands every other
// command to the runtime dispatcher.
//
// args is the command line without the program name. Help and version output
// is written to stdout; the logger writes to stderr. A malformed global flag
// is returned as a usage error. -q/--quiet limits logging to errors and
// -v/--verbose enables debug logging; verbose wins when both are given.
func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error {
	if len(args) == 0 || args[0] == "help" || args[0] == "--help" || args[0] == "-h" {
		printUsage(stdout)
		return nil
	}

	global := flag.NewFlagSet("discrawl", flag.ContinueOnError)
	// Silence the flag package's own diagnostics; failures surface through
	// the returned error instead.
	global.SetOutput(io.Discard)
	configPath := global.String("config", "", "")
	jsonOut := global.Bool("json", false, "")
	plainOut := global.Bool("plain", false, "")
	quiet := global.Bool("quiet", false, "")
	global.BoolVar(quiet, "q", false, "")
	verbose := global.Bool("verbose", false, "")
	global.BoolVar(verbose, "v", false, "")
	versionFlag := global.Bool("version", false, "")
	// Registered so the flag parses; its value is not read in this function.
	global.Bool("no-color", false, "")

	if err := global.Parse(args); err != nil {
		// -h/--help is not a registered flag, so when it follows another
		// global flag Parse reports flag.ErrHelp. That is a help request,
		// not a usage mistake.
		if errors.Is(err, flag.ErrHelp) {
			printUsage(stdout)
			return nil
		}
		return usageErr(err)
	}
	if *versionFlag {
		_, _ = io.WriteString(stdout, version+"\n")
		return nil
	}

	rest := global.Args()
	if len(rest) == 0 || rest[0] == "help" || rest[0] == "--help" || rest[0] == "-h" {
		printUsage(stdout)
		return nil
	}
	if rest[0] == "version" {
		_, _ = io.WriteString(stdout, version+"\n")
		return nil
	}

	level := slog.LevelInfo
	if *quiet {
		level = slog.LevelError
	}
	if *verbose {
		level = slog.LevelDebug
	}

	// Named rt rather than runtime so the local does not shadow the type.
	rt := &runtime{
		ctx:        ctx,
		configPath: config.ResolvePath(*configPath),
		stdout:     stdout,
		stderr:     stderr,
		json:       *jsonOut,
		plain:      *plainOut,
		logger:     slog.New(slog.NewTextHandler(stderr, &slog.HandlerOptions{Level: level})),
	}
	return rt.dispatch(rest)
}
// runtime carries the per-invocation state shared by the command handlers:
// the resolved global flags, the loaded configuration, the output writers,
// and the store and Discord services opened on demand by the with* helpers.
type runtime struct {
// ctx is the context handed to Run; it is passed to store and share calls.
ctx context.Context
// configPath is the resolved configuration file path given to config.Load.
configPath string
// cfg is the active configuration, assigned by the with* helpers.
cfg config.Config
// stdout and stderr are the writers supplied to Run.
stdout io.Writer
stderr io.Writer
// json and plain mirror the --json and --plain global flags.
json bool
plain bool
// logger is a text logger writing to stderr at the flag-selected level.
logger *slog.Logger
// store is the currently open database handle; the helpers set it before
// running a command callback.
store *store.Store
// client and syncer are populated by ensureDiscordServices.
client discordClient
syncer syncService
// NOTE(review): dbLockHeld and lockStarted are not touched in this part of
// the file; presumably maintained by the sync-lock helpers — confirm.
dbLockHeld bool
lockStarted time.Time
// openStore, newDiscord and newSyncer are optional factory overrides; when
// nil the helpers fall back to store.Open, discord.New and syncer.New.
openStore func(context.Context, string) (*store.Store, error)
newDiscord func(config.Config) (discordClient, error)
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
// NOTE(review): newEmbed and now are not used in this part of the file;
// they look like further injection points for embeddings and the clock —
// confirm against their callers.
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
now func() time.Time
}
// crawlkitEmbeddingConfig translates discrawl's embeddings settings into a
// crawlkit embed.Config, copying each field across one-to-one.
func crawlkitEmbeddingConfig(cfg config.EmbeddingsConfig) embed.Config {
	var out embed.Config
	out.Provider = cfg.Provider
	out.Model = cfg.Model
	out.BaseURL = cfg.BaseURL
	out.APIKeyEnv = cfg.APIKeyEnv
	out.RequestTimeout = cfg.RequestTimeout
	out.MaxInputChars = cfg.MaxInputChars
	return out
}
// discordClient is the Discord API surface the CLI depends on: everything the
// syncer needs (syncer.Client) plus Close, Self and Guilds.
type discordClient interface {
syncer.Client
// Close releases the client; openServices defers it after a command runs.
Close() error
// Self returns a Discord user; presumably the authenticated account — confirm.
Self(context.Context) (*discordgo.User, error)
// Guilds returns a list of user guilds; presumably those visible to the
// authenticated account — confirm.
Guilds(context.Context) ([]*discordgo.UserGuild, error)
}
// syncService is the synchronisation behaviour the CLI needs from a syncer.
// ensureDiscordServices builds one from the Discord client, the store and the
// logger.
type syncService interface {
// DiscoverGuilds returns a list of user guilds.
DiscoverGuilds(context.Context) ([]*discordgo.UserGuild, error)
// Sync runs a sync with the given options and reports its statistics.
Sync(context.Context, syncer.SyncOptions) (syncer.SyncStats, error)
// RunTail takes a string slice and a duration; NOTE(review): presumably
// guild IDs and a repair/poll interval — confirm against the syncer package.
RunTail(context.Context, []string, time.Duration) error
}
// attachmentTextConfigurer is an optional capability of a syncService.
// ensureDiscordServices checks for it with a type assertion and, when present,
// passes in the configuration's AttachmentTextEnabled value.
type attachmentTextConfigurer interface {
SetAttachmentTextEnabled(bool)
}
// dispatch routes the command named by rest[0] to its handler, wrapping the
// handler in the helper that provides the store access, locking, share
// update and Discord setup that command needs. rest must be non-empty; the
// remaining elements are passed to the handler. An unrecognised command is
// reported as a usage error.
func (r *runtime) dispatch(rest []string) error {
	cmd, args := rest[0], rest[1:]
	switch cmd {
	case "metadata":
		return r.runMetadata(args)
	case "init":
		return r.runInit(args)
	case "sync":
		updateMode, err := syncShareUpdateMode(args)
		if err != nil {
			return usageErr(err)
		}
		return r.withLocalStoreUpdateLocked(updateMode, true, func() error { return r.runSync(args) })
	case "tail":
		return r.withServicesLocked(true, func() error { return r.runTail(args) })
	case "wiretap", "tap", "cache-import":
		// All three names run the same wiretap handler.
		return r.withLocalStoreLocked(false, func() error { return r.runWiretap(args) })
	case "search":
		// Share auto-update is skipped for --dm searches.
		shareUpdate := !hasBoolFlag(args, "--dm")
		return r.withLocalStoreRead(shareUpdate, func() error { return r.runSearch(args) })
	case "tui":
		// A help request is answered without opening the store.
		if hasHelpArg(args) {
			return r.runTUI(args)
		}
		return r.withLocalStoreReadOnly(func() error { return r.runTUI(args) })
	case "messages":
		run := func() error { return r.runMessages(args) }
		// --sync without --dm needs Discord and the sync lock.
		if hasBoolFlag(args, "--sync") && !hasBoolFlag(args, "--dm") {
			return r.withServicesAutoLocked(true, true, true, run)
		}
		shareUpdate := !hasBoolFlag(args, "--dm")
		return r.withLocalStoreRead(shareUpdate, run)
	case "digest":
		return r.withLocalStoreRead(true, func() error { return r.runDigest(args) })
	case "analytics":
		return r.runAnalytics(args)
	case "dms":
		return r.withLocalStoreRead(false, func() error { return r.runDirectMessages(args) })
	case "mentions":
		return r.withLocalStoreRead(true, func() error { return r.runMentions(args) })
	case "embed":
		return r.withLocalStoreLocked(true, func() error { return r.runEmbed(args) })
	case "sql":
		run := func() error { return r.runSQL(args) }
		// --unsafe takes the locked, writable store path.
		if boolFlagEnabled(args, "--unsafe") {
			return r.withLocalStoreLocked(true, run)
		}
		return r.withLocalStoreRead(true, run)
	case "members":
		return r.withLocalStoreRead(true, func() error { return r.runMembers(args) })
	case "channels":
		return r.withLocalStoreRead(true, func() error { return r.runChannels(args) })
	case "status":
		return r.withLocalStoreReadOnly(func() error { return r.runStatus(args) })
	case "report":
		return r.withLocalStoreRead(true, func() error { return r.runReport(args) })
	case "publish":
		return r.withServicesAutoLocked(false, false, true, func() error { return r.runPublish(args) })
	case "subscribe":
		return r.runSubscribe(args)
	case "update":
		return r.withServicesAutoLocked(false, false, true, func() error { return r.runUpdate(args) })
	case "doctor":
		return r.runDoctor(args)
	default:
		return usageErr(fmt.Errorf("unknown command %q", cmd))
	}
}
// withServices runs fn via withServicesAuto, enabling share auto-update
// exactly when Discord is not requested.
func (r *runtime) withServices(withDiscord bool, fn func() error) error {
	autoShareUpdate := !withDiscord
	return r.withServicesAuto(withDiscord, autoShareUpdate, fn)
}
// withServicesLocked runs fn via withServicesAutoLocked with the database
// lock requested; share auto-update is enabled exactly when Discord is not
// requested.
func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error {
	autoShareUpdate := !withDiscord
	const lockDB = true
	return r.withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB, fn)
}
// withLocalStoreLocked runs fn via withLocalStoreUpdateLocked with the
// database lock requested, translating autoShareUpdate into a share update
// mode with boolShareUpdateMode.
func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error {
	mode := boolShareUpdateMode(autoShareUpdate)
	return r.withLocalStoreUpdateLocked(mode, true, fn)
}
// withLocalStoreRead runs fn via withLocalStoreReadUpdate, translating
// autoShareUpdate into a share update mode with boolShareUpdateMode.
func (r *runtime) withLocalStoreRead(autoShareUpdate bool, fn func() error) error {
	mode := boolShareUpdateMode(autoShareUpdate)
	return r.withLocalStoreReadUpdate(mode, fn)
}
// withLocalStoreReadUpdate loads the configuration (falling back to
// normalised defaults when the config file does not exist), ensures the
// runtime directories, optionally refreshes the share if the sync lock is
// free, and then runs fn through openLocalStoreReadOnly.
//
// Configuration problems are returned as config errors; a share-update
// failure is returned as-is.
func (r *runtime) withLocalStoreReadUpdate(updateMode shareUpdateMode, fn func() error) error {
	cfg, loadErr := config.Load(r.configPath)
	switch {
	case loadErr == nil:
		// Loaded configuration is used unchanged.
	case errors.Is(loadErr, os.ErrNotExist):
		cfg = config.Default()
		if normErr := cfg.Normalize(); normErr != nil {
			return configErr(normErr)
		}
	default:
		return configErr(loadErr)
	}

	if dirErr := config.EnsureRuntimeDirs(cfg); dirErr != nil {
		return configErr(dirErr)
	}
	dbPath, pathErr := config.ExpandPath(cfg.DBPath)
	if pathErr != nil {
		return configErr(pathErr)
	}
	r.cfg = cfg

	if !r.shouldAutoUpdateShare(updateMode) {
		return r.openLocalStoreReadOnly(dbPath, fn)
	}
	if updErr := r.autoUpdateShareIfLockAvailable(dbPath, updateMode); updErr != nil {
		return updErr
	}
	return r.openLocalStoreReadOnly(dbPath, fn)
}
// withLocalStoreUpdateLocked loads the configuration (falling back to
// normalised defaults when the config file does not exist), ensures the
// runtime directories, and runs fn through openLocalStore. When lockDB is
// true the store is opened and fn is run inside withSyncLock.
//
// Configuration problems are returned as config errors.
func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error {
	cfg, loadErr := config.Load(r.configPath)
	switch {
	case loadErr == nil:
		// Loaded configuration is used unchanged.
	case errors.Is(loadErr, os.ErrNotExist):
		cfg = config.Default()
		if normErr := cfg.Normalize(); normErr != nil {
			return configErr(normErr)
		}
	default:
		return configErr(loadErr)
	}

	if dirErr := config.EnsureRuntimeDirs(cfg); dirErr != nil {
		return configErr(dirErr)
	}
	dbPath, pathErr := config.ExpandPath(cfg.DBPath)
	if pathErr != nil {
		return configErr(pathErr)
	}
	r.cfg = cfg

	run := func() error {
		return r.openLocalStore(dbPath, updateMode, fn)
	}
	if !lockDB {
		return run()
	}
	return r.withSyncLock(run)
}
// shouldAutoUpdateShare reports whether a share update should be attempted
// for mode. It is false when DISCRAWL_NO_AUTO_UPDATE is "1" or sharing is
// disabled. Otherwise the force and auto modes always qualify, and the
// configured mode qualifies only when share.auto_update is set.
func (r *runtime) shouldAutoUpdateShare(mode shareUpdateMode) bool {
	if os.Getenv("DISCRAWL_NO_AUTO_UPDATE") == "1" {
		return false
	}
	if !r.cfg.ShareEnabled() {
		return false
	}
	if mode == shareUpdateForce || mode == shareUpdateAuto {
		return true
	}
	return mode == shareUpdateConfigured && r.cfg.Share.AutoUpdate
}
// autoUpdateShareIfLockAvailable tries to take the sync lock and, if it gets
// it, opens the store at dbPath, runs autoUpdateShare, and closes the store
// again (leaving r.store nil). When the lock cannot be taken the update is
// skipped with an info log and nil is returned.
func (r *runtime) autoUpdateShareIfLockAvailable(dbPath string, updateMode shareUpdateMode) error {
	update := func() error {
		open := r.openStore
		if open == nil {
			open = store.Open
		}
		var err error
		if r.store, err = open(r.ctx, dbPath); err != nil {
			return dbErr(err)
		}
		// The store is only needed for the update itself; release it so the
		// caller can reopen it in whatever mode it wants.
		defer func() {
			_ = r.store.Close()
			r.store = nil
		}()
		return r.autoUpdateShare(updateMode)
	}

	acquired, err := r.tryWithSyncLock(update)
	if err != nil {
		return err
	}
	if !acquired {
		r.logger.Info("share update skipped; sync lock is held")
	}
	return nil
}
// openLocalStore opens the store at dbPath (through r.openStore when set,
// otherwise store.Open), optionally refreshes the share, runs fn, and closes
// the store when it returns. The share refresh is skipped when updateMode is
// shareUpdateNever or DISCRAWL_NO_AUTO_UPDATE is "1".
func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error {
	open := r.openStore
	if open == nil {
		open = store.Open
	}

	var openErr error
	if r.store, openErr = open(r.ctx, dbPath); openErr != nil {
		return dbErr(openErr)
	}
	defer func() { _ = r.store.Close() }()

	updateAllowed := updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1"
	if updateAllowed {
		if updErr := r.autoUpdateShare(updateMode); updErr != nil {
			return updErr
		}
	}
	return fn()
}
// withLocalStoreReadOnly loads the configuration (falling back to normalised
// defaults when the config file does not exist) and runs fn with the
// database opened read-only. It does not create runtime directories. If the
// database file does not exist, fn still runs, with r.store set to nil.
func (r *runtime) withLocalStoreReadOnly(fn func() error) error {
	cfg, loadErr := config.Load(r.configPath)
	switch {
	case loadErr == nil:
		// Loaded configuration is used unchanged.
	case errors.Is(loadErr, os.ErrNotExist):
		cfg = config.Default()
		if normErr := cfg.Normalize(); normErr != nil {
			return configErr(normErr)
		}
	default:
		return configErr(loadErr)
	}

	dbPath, pathErr := config.ExpandPath(cfg.DBPath)
	if pathErr != nil {
		return configErr(pathErr)
	}
	r.cfg = cfg

	var openErr error
	r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
	switch {
	case openErr == nil:
		defer func() { _ = r.store.Close() }()
	case errors.Is(openErr, os.ErrNotExist):
		// No database yet: let the command run without a store.
		r.store = nil
	default:
		return dbErr(openErr)
	}
	return fn()
}
// openLocalStoreReadOnly runs fn against the database at dbPath, preferring
// a read-only handle, and closes the store when fn returns.
//
// Two recoveries are attempted when the read-only open fails:
//   - the file does not exist: open it through the writable factory
//     (r.openStore, or store.Open) and run fn on that handle;
//   - the schema version does not match: under the sync lock, open and
//     immediately close a writable handle, then retry the read-only open.
//
// Any open error still standing after that is returned as a database error.
func (r *runtime) openLocalStoreReadOnly(dbPath string, fn func() error) error {
	// runWithStore invokes fn and closes whichever store is current afterwards.
	runWithStore := func() error {
		defer func() { _ = r.store.Close() }()
		return fn()
	}
	// openWritable assigns r.store from the writable factory.
	openWritable := func() error {
		open := r.openStore
		if open == nil {
			open = store.Open
		}
		var err error
		r.store, err = open(r.ctx, dbPath)
		return err
	}

	var openErr error
	r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
	if openErr == nil {
		return runWithStore()
	}

	if errors.Is(openErr, os.ErrNotExist) {
		if openErr = openWritable(); openErr == nil {
			return runWithStore()
		}
	}
	if !errors.Is(openErr, store.ErrSchemaVersionMismatch) {
		return dbErr(openErr)
	}

	migrate := func() error {
		if migrateErr := openWritable(); migrateErr != nil {
			return dbErr(migrateErr)
		}
		closeErr := r.store.Close()
		r.store = nil
		return closeErr
	}
	if err := r.withSyncLock(migrate); err != nil {
		return err
	}

	r.store, openErr = store.OpenReadOnly(r.ctx, dbPath)
	if openErr != nil {
		return dbErr(openErr)
	}
	return runWithStore()
}
// withServicesAuto runs fn via withServicesAutoLocked without requesting the
// database lock.
func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func() error) error {
	const lockDB = false
	return r.withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB, fn)
}
// withServicesAutoLocked runs fn via withServicesUpdateLocked, translating
// autoShareUpdate into a share update mode with boolShareUpdateMode.
func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error {
	mode := boolShareUpdateMode(autoShareUpdate)
	return r.withServicesUpdateLocked(withDiscord, mode, lockDB, fn)
}
// withServicesUpdateLocked loads the configuration, ensures the runtime
// directories, and runs fn through openServices. Unlike the local-store
// helpers there is no fallback to defaults: any failure to load the
// configuration is a config error. When lockDB is true the services are
// opened and fn is run inside withSyncLock.
func (r *runtime) withServicesUpdateLocked(withDiscord bool, updateMode shareUpdateMode, lockDB bool, fn func() error) error {
	cfg, loadErr := config.Load(r.configPath)
	if loadErr != nil {
		return configErr(loadErr)
	}
	if dirErr := config.EnsureRuntimeDirs(cfg); dirErr != nil {
		return configErr(dirErr)
	}
	dbPath, pathErr := config.ExpandPath(cfg.DBPath)
	if pathErr != nil {
		return configErr(pathErr)
	}
	r.cfg = cfg

	run := func() error {
		return r.openServices(dbPath, withDiscord, updateMode, fn)
	}
	if !lockDB {
		return run()
	}
	return r.withSyncLock(run)
}
// openServices opens the store at dbPath (through r.openStore when set,
// otherwise store.Open), optionally refreshes the share, sets up the Discord
// client and syncer when withDiscord is true, and then runs fn. The store,
// and the Discord client if one was created, are closed when it returns.
// The share refresh is skipped when updateMode is shareUpdateNever or
// DISCRAWL_NO_AUTO_UPDATE is "1".
func (r *runtime) openServices(dbPath string, withDiscord bool, updateMode shareUpdateMode, fn func() error) error {
	open := r.openStore
	if open == nil {
		open = store.Open
	}

	var openErr error
	if r.store, openErr = open(r.ctx, dbPath); openErr != nil {
		return dbErr(openErr)
	}
	defer func() { _ = r.store.Close() }()

	updateAllowed := updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1"
	if updateAllowed {
		if updErr := r.autoUpdateShare(updateMode); updErr != nil {
			return updErr
		}
	}

	if !withDiscord {
		return fn()
	}
	if svcErr := r.ensureDiscordServices(); svcErr != nil {
		return svcErr
	}
	if r.client != nil {
		defer func() { _ = r.client.Close() }()
	}
	return fn()
}
// ensureDiscordServices creates the Discord client and the syncer and stores
// them on r. The injected factories (r.newDiscord, r.newSyncer) are used
// when set; otherwise the client is built from the token resolved from the
// configuration and the syncer comes from syncer.New. A failure to create
// the client is returned as an auth error. If the syncer implements
// attachmentTextConfigurer it is given the configured attachment-text
// setting.
func (r *runtime) ensureDiscordServices() error {
	newClient := r.newDiscord
	if newClient == nil {
		newClient = func(cfg config.Config) (discordClient, error) {
			token, tokenErr := config.ResolveDiscordToken(cfg)
			if tokenErr != nil {
				return nil, tokenErr
			}
			return discord.New(token.Token)
		}
	}
	client, clientErr := newClient(r.cfg)
	if clientErr != nil {
		return authErr(clientErr)
	}
	r.client = client

	newSync := r.newSyncer
	if newSync == nil {
		newSync = func(c syncer.Client, st *store.Store, log *slog.Logger) syncService {
			return syncer.New(c, st, log)
		}
	}
	r.syncer = newSync(r.client, r.store, r.logger)

	// Attachment-text support is optional on the syncer.
	if toggler, ok := r.syncer.(attachmentTextConfigurer); ok {
		toggler.SetAttachmentTextEnabled(r.cfg.AttachmentTextEnabled())
	}
	return nil
}
// autoUpdateShare pulls the share repository and imports it into the open
// store when an update is due.
//
// It is a no-op when sharing is disabled, or when mode is
// shareUpdateConfigured and share.auto_update is off. Unless mode is
// shareUpdateForce, it is also a no-op when share.NeedsImport reports that
// no import is needed for the configured share.stale_after. An unparsable
// share.stale_after is a config error. A missing manifest during import is
// not treated as an error.
func (r *runtime) autoUpdateShare(mode shareUpdateMode) error {
	if !r.cfg.ShareEnabled() {
		return nil
	}
	if mode == shareUpdateConfigured && !r.cfg.Share.AutoUpdate {
		return nil
	}

	staleAfter, parseErr := time.ParseDuration(r.cfg.Share.StaleAfter)
	if parseErr != nil {
		return configErr(fmt.Errorf("invalid share.stale_after: %w", parseErr))
	}
	if mode != shareUpdateForce {
		if !share.NeedsImport(r.ctx, r.store, staleAfter) {
			return nil
		}
	}

	opts, optsErr := r.shareOptions()
	if optsErr != nil {
		return optsErr
	}

	r.setSyncLockPhase("share pull")
	r.logger.Info("share update pulling", "repo_path", opts.RepoPath, "remote", opts.Remote)
	if pullErr := share.Pull(r.ctx, opts); pullErr != nil {
		return pullErr
	}

	r.setSyncLockPhase("share import")
	_, _, importErr := share.ImportIfChanged(r.ctx, r.store, opts)
	if errors.Is(importErr, share.ErrNoManifest) {
		return nil
	}
	return importErr
}
// shareOptions builds the share.Options for the configured share: the
// expanded repository path, the remote, the branch, and r.shareProgress as
// the progress hook. A repository path that cannot be expanded is returned
// as a config error alongside zero-valued options.
func (r *runtime) shareOptions() (share.Options, error) {
	var opts share.Options
	expanded, err := config.ExpandPath(r.cfg.Share.RepoPath)
	if err != nil {
		return opts, configErr(err)
	}
	opts.RepoPath = expanded
	opts.Remote = r.cfg.Share.Remote
	opts.Branch = r.cfg.Share.Branch
	opts.Progress = r.shareProgress
	return opts, nil
}