353 lines
9.9 KiB
Go
353 lines
9.9 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/bwmarrin/discordgo"
|
|
"github.com/steipete/discrawl/internal/config"
|
|
"github.com/steipete/discrawl/internal/discord"
|
|
"github.com/steipete/discrawl/internal/embed"
|
|
"github.com/steipete/discrawl/internal/share"
|
|
"github.com/steipete/discrawl/internal/store"
|
|
"github.com/steipete/discrawl/internal/syncer"
|
|
)
|
|
|
|
type cliError struct {
|
|
code int
|
|
err error
|
|
}
|
|
|
|
func (e *cliError) Error() string {
|
|
return e.err.Error()
|
|
}
|
|
|
|
func ExitCode(err error) int {
|
|
if err == nil {
|
|
return 0
|
|
}
|
|
var codeErr *cliError
|
|
if errors.As(err, &codeErr) {
|
|
return codeErr.code
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error {
|
|
global := flag.NewFlagSet("discrawl", flag.ContinueOnError)
|
|
global.SetOutput(io.Discard)
|
|
configPath := global.String("config", "", "")
|
|
jsonOut := global.Bool("json", false, "")
|
|
plainOut := global.Bool("plain", false, "")
|
|
quiet := global.Bool("quiet", false, "")
|
|
global.BoolVar(quiet, "q", false, "")
|
|
verbose := global.Bool("verbose", false, "")
|
|
global.BoolVar(verbose, "v", false, "")
|
|
versionFlag := global.Bool("version", false, "")
|
|
global.Bool("no-color", false, "")
|
|
if err := global.Parse(args); err != nil {
|
|
return usageErr(err)
|
|
}
|
|
if *versionFlag {
|
|
_, _ = io.WriteString(stdout, version+"\n")
|
|
return nil
|
|
}
|
|
rest := global.Args()
|
|
if len(rest) == 0 || rest[0] == "help" {
|
|
printUsage(stdout)
|
|
return nil
|
|
}
|
|
level := slog.LevelInfo
|
|
if *quiet {
|
|
level = slog.LevelError
|
|
}
|
|
if *verbose {
|
|
level = slog.LevelDebug
|
|
}
|
|
runtime := &runtime{
|
|
ctx: ctx,
|
|
configPath: config.ResolvePath(*configPath),
|
|
stdout: stdout,
|
|
stderr: stderr,
|
|
json: *jsonOut,
|
|
plain: *plainOut,
|
|
logger: slog.New(slog.NewTextHandler(stderr, &slog.HandlerOptions{Level: level})),
|
|
}
|
|
return runtime.dispatch(rest)
|
|
}
|
|
|
|
type runtime struct {
|
|
ctx context.Context
|
|
configPath string
|
|
cfg config.Config
|
|
stdout io.Writer
|
|
stderr io.Writer
|
|
json bool
|
|
plain bool
|
|
logger *slog.Logger
|
|
store *store.Store
|
|
client discordClient
|
|
syncer syncService
|
|
dbLockHeld bool
|
|
openStore func(context.Context, string) (*store.Store, error)
|
|
newDiscord func(config.Config) (discordClient, error)
|
|
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
|
|
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
|
|
now func() time.Time
|
|
}
|
|
|
|
type discordClient interface {
|
|
syncer.Client
|
|
Close() error
|
|
Self(context.Context) (*discordgo.User, error)
|
|
Guilds(context.Context) ([]*discordgo.UserGuild, error)
|
|
}
|
|
|
|
type syncService interface {
|
|
DiscoverGuilds(context.Context) ([]*discordgo.UserGuild, error)
|
|
Sync(context.Context, syncer.SyncOptions) (syncer.SyncStats, error)
|
|
RunTail(context.Context, []string, time.Duration) error
|
|
}
|
|
|
|
type attachmentTextConfigurer interface {
|
|
SetAttachmentTextEnabled(bool)
|
|
}
|
|
|
|
func (r *runtime) dispatch(rest []string) error {
|
|
switch rest[0] {
|
|
case "init":
|
|
return r.runInit(rest[1:])
|
|
case "sync":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runSync(rest[1:]) })
|
|
case "tail":
|
|
return r.withServicesLocked(true, func() error { return r.runTail(rest[1:]) })
|
|
case "wiretap":
|
|
return r.withLocalStoreLocked(false, func() error { return r.runWiretap(rest[1:]) })
|
|
case "search":
|
|
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
|
return r.withLocalStoreDefaultLocked(autoShareUpdate, autoShareUpdate, func() error { return r.runSearch(rest[1:]) })
|
|
case "messages":
|
|
if hasBoolFlag(rest[1:], "--sync") && !hasBoolFlag(rest[1:], "--dm") {
|
|
return r.withServicesAutoLocked(true, true, true, func() error { return r.runMessages(rest[1:]) })
|
|
}
|
|
autoShareUpdate := !hasBoolFlag(rest[1:], "--dm")
|
|
return r.withLocalStoreDefaultLocked(autoShareUpdate, autoShareUpdate, func() error { return r.runMessages(rest[1:]) })
|
|
case "digest":
|
|
return r.withLocalStoreDefaultLocked(true, true, func() error { return r.runDigest(rest[1:]) })
|
|
case "analytics":
|
|
return r.runAnalytics(rest[1:])
|
|
case "dms":
|
|
return r.withLocalStoreDefault(false, func() error { return r.runDirectMessages(rest[1:]) })
|
|
case "mentions":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runMentions(rest[1:]) })
|
|
case "embed":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runEmbed(rest[1:]) })
|
|
case "sql":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runSQL(rest[1:]) })
|
|
case "members":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runMembers(rest[1:]) })
|
|
case "channels":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runChannels(rest[1:]) })
|
|
case "status":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runStatus(rest[1:]) })
|
|
case "report":
|
|
return r.withLocalStoreLocked(true, func() error { return r.runReport(rest[1:]) })
|
|
case "publish":
|
|
return r.withServicesAutoLocked(false, false, true, func() error { return r.runPublish(rest[1:]) })
|
|
case "subscribe":
|
|
return r.runSubscribe(rest[1:])
|
|
case "update":
|
|
return r.withServicesAutoLocked(false, false, true, func() error { return r.runUpdate(rest[1:]) })
|
|
case "doctor":
|
|
return r.runDoctor(rest[1:])
|
|
default:
|
|
return usageErr(fmt.Errorf("unknown command %q", rest[0]))
|
|
}
|
|
}
|
|
|
|
func (r *runtime) withServices(withDiscord bool, fn func() error) error {
|
|
return r.withServicesAuto(withDiscord, !withDiscord, fn)
|
|
}
|
|
|
|
func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error {
|
|
return r.withServicesAutoLocked(withDiscord, !withDiscord, true, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error {
|
|
return r.withLocalStoreDefaultLocked(autoShareUpdate, true, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreDefault(autoShareUpdate bool, fn func() error) error {
|
|
return r.withLocalStoreDefaultLocked(autoShareUpdate, false, fn)
|
|
}
|
|
|
|
func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return configErr(err)
|
|
}
|
|
cfg = config.Default()
|
|
if err := cfg.Normalize(); err != nil {
|
|
return configErr(err)
|
|
}
|
|
}
|
|
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
|
return configErr(err)
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
if lockDB {
|
|
return r.withSyncLock(func() error {
|
|
return r.openLocalStore(dbPath, autoShareUpdate, fn)
|
|
})
|
|
}
|
|
return r.openLocalStore(dbPath, autoShareUpdate, fn)
|
|
}
|
|
|
|
func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func() error) error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var err error
|
|
r.store, err = storeFactory(r.ctx, dbPath)
|
|
if err != nil {
|
|
return dbErr(err)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
|
if err := r.autoUpdateShare(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func() error) error {
|
|
return r.withServicesAutoLocked(withDiscord, autoShareUpdate, false, fn)
|
|
}
|
|
|
|
func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error {
|
|
cfg, err := config.Load(r.configPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
if err := config.EnsureRuntimeDirs(cfg); err != nil {
|
|
return configErr(err)
|
|
}
|
|
dbPath, err := config.ExpandPath(cfg.DBPath)
|
|
if err != nil {
|
|
return configErr(err)
|
|
}
|
|
r.cfg = cfg
|
|
if lockDB {
|
|
return r.withSyncLock(func() error {
|
|
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
|
|
})
|
|
}
|
|
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
|
|
}
|
|
|
|
func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool, fn func() error) error {
|
|
storeFactory := r.openStore
|
|
if storeFactory == nil {
|
|
storeFactory = store.Open
|
|
}
|
|
var err error
|
|
r.store, err = storeFactory(r.ctx, dbPath)
|
|
if err != nil {
|
|
return dbErr(err)
|
|
}
|
|
defer func() { _ = r.store.Close() }()
|
|
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
|
if err := r.autoUpdateShare(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if withDiscord {
|
|
if err := r.ensureDiscordServices(); err != nil {
|
|
return err
|
|
}
|
|
if r.client != nil {
|
|
defer func() { _ = r.client.Close() }()
|
|
}
|
|
}
|
|
return fn()
|
|
}
|
|
|
|
func (r *runtime) ensureDiscordServices() error {
|
|
discordFactory := r.newDiscord
|
|
if discordFactory == nil {
|
|
discordFactory = func(cfg config.Config) (discordClient, error) {
|
|
token, err := config.ResolveDiscordToken(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return discord.New(token.Token)
|
|
}
|
|
}
|
|
client, err := discordFactory(r.cfg)
|
|
if err != nil {
|
|
return authErr(err)
|
|
}
|
|
r.client = client
|
|
syncerFactory := r.newSyncer
|
|
if syncerFactory == nil {
|
|
syncerFactory = func(client syncer.Client, s *store.Store, logger *slog.Logger) syncService {
|
|
return syncer.New(client, s, logger)
|
|
}
|
|
}
|
|
r.syncer = syncerFactory(r.client, r.store, r.logger)
|
|
if configurable, ok := r.syncer.(attachmentTextConfigurer); ok {
|
|
configurable.SetAttachmentTextEnabled(r.cfg.AttachmentTextEnabled())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *runtime) autoUpdateShare() error {
|
|
if !r.cfg.ShareEnabled() || !r.cfg.Share.AutoUpdate {
|
|
return nil
|
|
}
|
|
staleAfter, err := time.ParseDuration(r.cfg.Share.StaleAfter)
|
|
if err != nil {
|
|
return configErr(fmt.Errorf("invalid share.stale_after: %w", err))
|
|
}
|
|
if !share.NeedsImport(r.ctx, r.store, staleAfter) {
|
|
return nil
|
|
}
|
|
opts, err := r.shareOptions()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := share.Pull(r.ctx, opts); err != nil {
|
|
return err
|
|
}
|
|
_, _, err = share.ImportIfChanged(r.ctx, r.store, opts)
|
|
if errors.Is(err, share.ErrNoManifest) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (r *runtime) shareOptions() (share.Options, error) {
|
|
repoPath, err := config.ExpandPath(r.cfg.Share.RepoPath)
|
|
if err != nil {
|
|
return share.Options{}, configErr(err)
|
|
}
|
|
return share.Options{
|
|
RepoPath: repoPath,
|
|
Remote: r.cfg.Share.Remote,
|
|
Branch: r.cfg.Share.Branch,
|
|
}, nil
|
|
}
|