diff --git a/internal/cmd/drive.go b/internal/cmd/drive.go index bbd1bf3..a5bc042 100644 --- a/internal/cmd/drive.go +++ b/internal/cmd/drive.go @@ -99,6 +99,8 @@ type DriveSearchCmd struct { Max int64 `name:"max" aliases:"limit" help:"Max results" default:"20"` Page string `name:"page" aliases:"cursor" help:"Page token"` AllDrives bool `name:"all-drives" help:"Include shared drives (default: true; use --no-all-drives for My Drive only)" default:"true" negatable:"_"` + Drive string `name:"drive" aliases:"drive-id" help:"Scope search to a specific shared drive (uses corpora=drive with driveId). Mutually exclusive with --no-all-drives. Pass the driveId from 'gog drive drives'."` + Parent string `name:"parent" help:"Scope search to direct children of a specific folder or shared drive. Wraps the query with \"'' in parents\"."` } type DriveGetCmd struct { @@ -963,10 +965,15 @@ func downloadDriveFile(ctx context.Context, svc *drive.Service, meta *drive.File return outPath, n, nil } -func driveFilesListCallWithDriveSupport(call *drive.FilesListCall, allDrives bool) *drive.FilesListCall { +func driveFilesListCallWithDriveSupport(call *drive.FilesListCall, allDrives bool, driveID string) *drive.FilesListCall { // SupportsAllDrives must be set for shared drive file IDs to behave correctly. call = call.SupportsAllDrives(true).IncludeItemsFromAllDrives(allDrives) - if allDrives { + if driveID != "" { + // Scoped search within a specific shared drive. The Drive API requires + // corpora=drive + driveId together, and includeItemsFromAllDrives=true — + // which is why callers must guard against driveID!="" with allDrives=false. + call = call.Corpora("drive").DriveId(driveID) + } else if allDrives { call = call.Corpora("allDrives") } return call diff --git a/internal/cmd/drive_listing.go b/internal/cmd/drive_listing.go index cac55ca..1ee4535 100644 --- a/internal/cmd/drive_listing.go +++ b/internal/cmd/drive_listing.go @@ -17,6 +17,7 @@ type driveFileListOptions struct { max int64 page string allDrives bool + driveID string } func (c *DriveLsCmd) Run(ctx context.Context, flags *RootFlags) error { @@ -58,16 +59,29 @@ func (c *DriveSearchCmd) Run(ctx context.Context, flags *RootFlags) error { return usage("missing query") } + if c.Drive != "" && !c.AllDrives { + return usage("--drive cannot be combined with --no-all-drives") + } + if c.Parent != "" && c.RawQuery { + return usage("--parent cannot be combined with --raw-query; include the \"'' in parents\" clause in your raw query instead") + } + _, svc, err := requireDriveService(ctx, flags) if err != nil { return err } + finalQuery := buildDriveSearchQuery(query, c.RawQuery) + if c.Parent != "" { + finalQuery = fmt.Sprintf("'%s' in parents and %s", c.Parent, finalQuery) + } + resp, err := listDriveFiles(ctx, svc, driveFileListOptions{ - query: buildDriveSearchQuery(query, c.RawQuery), + query: finalQuery, max: c.Max, page: c.Page, allDrives: c.AllDrives, + driveID: c.Drive, }) if err != nil { return err @@ -82,7 +96,7 @@ func listDriveFiles(ctx context.Context, svc *drive.Service, opts driveFileListO PageSize(opts.max). PageToken(opts.page). OrderBy("modifiedTime desc") - call = driveFilesListCallWithDriveSupport(call, opts.allDrives) + call = driveFilesListCallWithDriveSupport(call, opts.allDrives, opts.driveID) return call.Fields(driveFileListFields).Context(ctx).Do() } diff --git a/internal/cmd/drive_search_more_test.go b/internal/cmd/drive_search_more_test.go index 53ada3c..f15cc8a 100644 --- a/internal/cmd/drive_search_more_test.go +++ b/internal/cmd/drive_search_more_test.go @@ -314,3 +314,230 @@ func TestDriveSearchCmd_RawQueryBypassesFullTextWrapping(t *testing.T) { t.Fatalf("execute: %v", execErr) } } + +func TestDriveSearchCmd_WithDrive(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + + const wantDriveID = "0AFakeSharedDriveID" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + path := strings.TrimPrefix(r.URL.Path, "/drive/v3") + if path != "/files" { + http.NotFound(w, r) + return + } + q := r.URL.Query() + if q.Get("supportsAllDrives") != "true" { + http.Error(w, "missing supportsAllDrives=true", http.StatusBadRequest) + return + } + if q.Get("includeItemsFromAllDrives") != "true" { + http.Error(w, "missing includeItemsFromAllDrives=true", http.StatusBadRequest) + return + } + if got := q.Get("corpora"); got != "drive" { + http.Error(w, "want corpora=drive, got "+got, http.StatusBadRequest) + return + } + if got := q.Get("driveId"); got != wantDriveID { + http.Error(w, "want driveId="+wantDriveID+", got "+got, http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"files": []map[string]any{}}) + })) + t.Cleanup(srv.Close) + + svc, err := drive.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newDriveService = func(context.Context, string) (*drive.Service, error) { return svc, nil } + + flags := &RootFlags{Account: "a@b.com"} + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + + cmd := &DriveSearchCmd{} + if execErr := runKong(t, cmd, []string{"hello", "--drive", wantDriveID}, ctx, flags); execErr != nil { + t.Fatalf("execute: %v", execErr) + } +} + +func TestDriveSearchCmd_WithParent(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + + const parentID = "1FakeFolderID" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + path := strings.TrimPrefix(r.URL.Path, "/drive/v3") + if path != "/files" { + http.NotFound(w, r) + return + } + if errMsg := driveAllDrivesQueryError(r, true); errMsg != "" { + http.Error(w, errMsg, http.StatusBadRequest) + return + } + got := r.URL.Query().Get("q") + if !strings.Contains(got, "'"+parentID+"' in parents") { + http.Error(w, "missing parent clause in q: "+got, http.StatusBadRequest) + return + } + if !strings.Contains(got, "fullText contains 'hello'") { + http.Error(w, "missing fullText clause in q: "+got, http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"files": []map[string]any{}}) + })) + t.Cleanup(srv.Close) + + svc, err := drive.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newDriveService = func(context.Context, string) (*drive.Service, error) { return svc, nil } + + flags := &RootFlags{Account: "a@b.com"} + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + + cmd := &DriveSearchCmd{} + if execErr := runKong(t, cmd, []string{"hello", "--parent", parentID}, ctx, flags); execErr != nil { + t.Fatalf("execute: %v", execErr) + } +} + +func TestDriveSearchCmd_DriveAndParent_Combine(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + + const driveID = "0AFakeSharedDriveID" + const parentID = "1FakeFolderID" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + path := strings.TrimPrefix(r.URL.Path, "/drive/v3") + if path != "/files" { + http.NotFound(w, r) + return + } + q := r.URL.Query() + if got := q.Get("corpora"); got != "drive" { + http.Error(w, "want corpora=drive, got "+got, http.StatusBadRequest) + return + } + if got := q.Get("driveId"); got != driveID { + http.Error(w, "want driveId="+driveID+", got "+got, http.StatusBadRequest) + return + } + got := q.Get("q") + if !strings.Contains(got, "'"+parentID+"' in parents") { + http.Error(w, "missing parent clause in q: "+got, http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"files": []map[string]any{}}) + })) + t.Cleanup(srv.Close) + + svc, err := drive.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newDriveService = func(context.Context, string) (*drive.Service, error) { return svc, nil } + + flags := &RootFlags{Account: "a@b.com"} + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + + cmd := &DriveSearchCmd{} + if execErr := runKong(t, cmd, []string{"hello", "--drive", driveID, "--parent", parentID}, ctx, flags); execErr != nil { + t.Fatalf("execute: %v", execErr) + } +} + +func TestDriveSearchCmd_DriveAndNoAllDrives_Conflicts(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + newDriveService = func(context.Context, string) (*drive.Service, error) { + t.Fatal("newDriveService should not be called when flags conflict") + return nil, nil + } + + flags := &RootFlags{Account: "a@b.com"} + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + + cmd := &DriveSearchCmd{} + err := runKong(t, cmd, []string{"hello", "--drive", "0AFake", "--no-all-drives"}, ctx, flags) + if err == nil { + t.Fatalf("expected error for --drive with --no-all-drives, got nil") + } + if !strings.Contains(err.Error(), "--drive") || !strings.Contains(err.Error(), "--no-all-drives") { + t.Fatalf("error should mention conflicting flags, got: %v", err) + } +} + +func TestDriveSearchCmd_ParentAndRawQuery_Conflicts(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + newDriveService = func(context.Context, string) (*drive.Service, error) { + t.Fatal("newDriveService should not be called when flags conflict") + return nil, nil + } + + flags := &RootFlags{Account: "a@b.com"} + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + + cmd := &DriveSearchCmd{} + err := runKong(t, cmd, []string{"someQuery", "--parent", "1FakeFolder", "--raw-query"}, ctx, flags) + if err == nil { + t.Fatalf("expected error for --parent with --raw-query, got nil") + } + if !strings.Contains(err.Error(), "--parent") || !strings.Contains(err.Error(), "--raw-query") { + t.Fatalf("error should mention conflicting flags, got: %v", err) + } +}