diff --git a/internal/notionapi/api.go b/internal/notionapi/api.go index 03df144..9255681 100644 --- a/internal/notionapi/api.go +++ b/internal/notionapi/api.go @@ -47,45 +47,50 @@ func (c Client) Sync(ctx context.Context, st *store.Store) (Summary, error) { c.HTTP = http.DefaultClient } var s Summary - users, err := c.listUsers(ctx) - if err != nil { - return s, err - } - for _, u := range users { - raw := notiontext.MarshalRaw(u) - if err := st.UpsertUser(ctx, store.User{ - ID: u.string("id"), Name: userName(u), Email: userEmail(u), RawJSON: raw, Source: SourceName, SyncedAt: store.NowMS(), - }); err != nil { - return s, err - } - s.Users++ - } - pages, err := c.searchPages(ctx) - if err != nil { - return s, err - } - for _, page := range pages { - count, comments, err := c.ingestPage(ctx, st, page, ingestPageOptions{FetchBlocks: true, FetchComments: true}) + if err := st.DeferPageFTS(ctx, func() error { + users, err := c.listUsers(ctx) if err != nil { - return s, err + return err } - s.Pages++ - s.Blocks += count - s.Comments += comments - } - collections, err := c.searchCollections(ctx) - if err != nil { - return s, err - } - for _, collection := range collections { - rows, err := c.ingestCollection(ctx, st, collection) + for _, u := range users { + raw := notiontext.MarshalRaw(u) + if err := st.UpsertUser(ctx, store.User{ + ID: u.string("id"), Name: userName(u), Email: userEmail(u), RawJSON: raw, Source: SourceName, SyncedAt: store.NowMS(), + }); err != nil { + return err + } + s.Users++ + } + pages, err := c.searchPages(ctx) if err != nil { - return s, err + return err } - s.Databases++ - s.DatabaseRows += rows - } - if err := st.SetSyncState(ctx, SourceName, "workspace", "default", time.Now().Format(time.RFC3339)); err != nil { + for _, page := range pages { + count, comments, err := c.ingestPage(ctx, st, page, ingestPageOptions{FetchBlocks: true, FetchComments: true}) + if err != nil { + return err + } + s.Pages++ + s.Blocks += count + s.Comments += comments + } + collections, err := c.searchCollections(ctx) + if err != nil { + return err + } + for _, collection := range collections { + rows, err := c.ingestCollection(ctx, st, collection) + if err != nil { + return err + } + s.Databases++ + s.DatabaseRows += rows + } + if err := st.SetSyncState(ctx, SourceName, "workspace", "default", time.Now().Format(time.RFC3339)); err != nil { + return err + } + return nil + }); err != nil { return s, err } return s, nil diff --git a/internal/notiondesktop/desktop.go b/internal/notiondesktop/desktop.go index c345982..d5de3ee 100644 --- a/internal/notiondesktop/desktop.go +++ b/internal/notiondesktop/desktop.go @@ -63,22 +63,27 @@ func Ingest(ctx context.Context, st *store.Store, path, cacheDir string) (Summar } defer db.Close() s := Summary{Source: source} - if s.Spaces, err = ingestSpaces(ctx, st, db); err != nil { - return s, err - } - if s.Users, err = ingestUsers(ctx, st, db); err != nil { - return s, err - } - if s.Teams, err = ingestTeams(ctx, st, db); err != nil { - return s, err - } - if s.Collections, err = ingestCollections(ctx, st, db); err != nil { - return s, err - } - if s.Pages, s.Blocks, s.RawRecords, err = ingestBlocks(ctx, st, db); err != nil { - return s, err - } - if s.Comments, err = ingestComments(ctx, st, db); err != nil { + if err := st.DeferPageFTS(ctx, func() error { + if s.Spaces, err = ingestSpaces(ctx, st, db); err != nil { + return err + } + if s.Users, err = ingestUsers(ctx, st, db); err != nil { + return err + } + if s.Teams, err = ingestTeams(ctx, st, db); err != nil { + return err + } + if s.Collections, err = ingestCollections(ctx, st, db); err != nil { + return err + } + if s.Pages, s.Blocks, s.RawRecords, err = ingestBlocks(ctx, st, db); err != nil { + return err + } + if s.Comments, err = ingestComments(ctx, st, db); err != nil { + return err + } + return nil + }); err != nil { return s, err } if err := st.SetSyncState(ctx, SourceName, "desktop", "notion.db", snapshot); err != nil { diff --git a/internal/store/store.go b/internal/store/store.go index a0df2a8..fa9ba1c 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -17,8 +17,10 @@ import ( const schemaVersion = 1 type Store struct { - db *sql.DB - path string + db *sql.DB + path string + deferredFTS int + deferredFTSPages map[string]bool } func Open(path string) (*Store, error) { @@ -379,7 +381,7 @@ func (s *Store) UpsertPage(ctx context.Context, x Page) error { if err != nil { return err } - return s.refreshPageFTS(ctx, x.ID) + return s.markPageFTS(ctx, x.ID) } func (s *Store) UpsertBlock(ctx context.Context, x Block) error { @@ -410,7 +412,7 @@ func (s *Store) UpsertBlock(ctx context.Context, x Block) error { return err } if x.PageID != "" { - return s.refreshPageFTS(ctx, x.PageID) + return s.markPageFTS(ctx, x.PageID) } return nil } @@ -461,6 +463,44 @@ func (s *Store) SetSyncState(ctx context.Context, source, entityType, entityID, return err } +func (s *Store) DeferPageFTS(ctx context.Context, fn func() error) error { + outer := s.deferredFTS == 0 + if outer { + s.deferredFTSPages = map[string]bool{} + } + s.deferredFTS++ + err := fn() + s.deferredFTS-- + if !outer { + return err + } + pages := s.deferredFTSPages + s.deferredFTSPages = nil + if err != nil { + return err + } + for pageID := range pages { + if err := s.refreshPageFTS(ctx, pageID); err != nil { + return err + } + } + return nil +} + +func (s *Store) markPageFTS(ctx context.Context, pageID string) error { + if pageID == "" { + return nil + } + if s.deferredFTS > 0 { + if s.deferredFTSPages == nil { + s.deferredFTSPages = map[string]bool{} + } + s.deferredFTSPages[pageID] = true + return nil + } + return s.refreshPageFTS(ctx, pageID) +} + func (s *Store) refreshPageFTS(ctx context.Context, pageID string) error { var title string if err := s.db.QueryRowContext(ctx, `select title from pages where id = ?`, pageID).Scan(&title); err != nil { diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 49f0679..a0a86ec 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -33,6 +33,42 @@ func TestStoreUpsertsAndSearchesPage(t *testing.T) { } } +func TestStoreDefersPageFTSRefresh(t *testing.T) { + st, err := Open(filepath.Join(t.TempDir(), "notcrawl.db")) + if err != nil { + t.Fatal(err) + } + defer st.Close() + ctx := context.Background() + now := NowMS() + err = st.DeferPageFTS(ctx, func() error { + if err := st.UpsertPage(ctx, Page{ID: "page1", Title: "Launch Plan", Alive: true, Source: "test", SyncedAt: now}); err != nil { + return err + } + if err := st.UpsertBlock(ctx, Block{ID: "block1", PageID: "page1", Type: "text", Text: "deferred sqlite refresh", Alive: true, Source: "test", SyncedAt: now}); err != nil { + return err + } + results, err := st.Search(ctx, "sqlite", 10) + if err != nil { + return err + } + if len(results) != 0 { + t.Fatalf("expected deferred FTS to stay stale inside callback, got %+v", results) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + results, err := st.Search(ctx, "sqlite", 10) + if err != nil { + t.Fatal(err) + } + if len(results) != 1 || results[0].ID != "page1" { + t.Fatalf("expected refreshed FTS after callback, got %+v", results) + } +} + func TestStoreOrdersBlocksByDisplayOrder(t *testing.T) { st, err := Open(filepath.Join(t.TempDir(), "notcrawl.db")) if err != nil {