Signal-iOS/Scripts/precommit.py
Evan Hahn 1b00741b6d
Fix remaining SwiftLint failures, lint more strictly
This fixes our remaining SwiftLint violations, which were small.

It also updates the precommit script to fail if any violations are
found, even warnings. This will cause CI to fail if you include a file
that isn't SwiftLint-compatible.
2022-10-07 12:00:26 -05:00

399 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import sys
import subprocess
import datetime
import argparse
from typing import Iterable
from pathlib import Path
# File extensions of the source files the precommit checks operate on.
EXTENSIONS_TO_CHECK = {".h", ".hpp", ".cpp", ".m", ".mm", ".pch", ".swift"}

# Absolute path of the top level of the enclosing git working tree, as
# reported by `git rev-parse --show-toplevel` (evaluated at import time).
git_repo_path = os.path.abspath(
    subprocess.check_output(
        ["git", "rev-parse", "--show-toplevel"],
        text=True,
    ).strip()
)
def sort_forward_decl_statement_block(text, filepath, filename, file_extension):
    """Return the declaration lines of *text* stripped, de-duplicated and sorted.

    Blank lines are dropped. The result is wrapped in exactly one leading and
    one trailing newline. The path arguments are unused; they exist so this
    function fits the ``sort_func`` signature expected by
    ``sort_matching_blocks``.
    """
    stripped_lines = (raw.strip() for raw in text.split("\n"))
    unique_decls = {decl for decl in stripped_lines if decl}
    body = "\n".join(sorted(unique_decls))
    return f"\n{body}\n"
def find_matching_section(text, match_test):
    """Split *text* around the first run of lines accepted by *match_test*.

    The run starts at the first line for which ``match_test(line)`` is truthy.
    It is extended backwards over any blank lines immediately preceding it.
    It is extended forwards over matching lines and blank lines, and stops at
    the first non-blank line that does not match.

    Returns:
        None if no line matches. Otherwise a tuple ``(before, section, after)``
        of newline-joined strings, where ``after`` is None when the run extends
        to the end of *text*.
    """
    lines = text.split("\n")

    start = next(
        (pos for pos, line in enumerate(lines) if match_test(line)), None
    )
    if start is None:
        return None

    # Pull blank lines that precede the run into the section.
    while start > 0 and not lines[start - 1].strip():
        start -= 1

    # First non-blank, non-matching line at or after the section start;
    # blank lines inside or trailing the run stay in the section.
    end = next(
        (
            pos
            for pos in range(start, len(lines))
            if lines[pos].strip() and not match_test(lines[pos])
        ),
        None,
    )

    before = "\n".join(lines[:start])
    if end is None:
        return before, "\n".join(lines[start:]), None
    return before, "\n".join(lines[start:end]), "\n".join(lines[end:])
def sort_matching_blocks(
    sort_name, filepath, filename, file_extension, text, match_func, sort_func
):
    """Rewrite every matching section of *text* with *sort_func* applied.

    *text* is scanned repeatedly with ``find_matching_section(..., match_func)``.
    Each section found is replaced by
    ``sort_func(section, filepath, filename, file_extension)``. The text
    between sections is kept unchanged, and the pieces are re-joined with
    newlines.

    Note: *match_func* is used as a per-line predicate. The section finders
    passed by this file's callers (e.g. ``find_forward_class_statement_section``)
    work here because, applied to a single line, they return a truthy tuple
    exactly when that line matches.

    If the result differs from *text*, ``sort_name`` and ``filepath`` are
    printed.

    Returns:
        The rewritten text.
    """
    pieces = []
    unprocessed = text
    while True:
        section = find_matching_section(unprocessed, match_func)
        if not section:
            pieces.append(unprocessed)
            break
        before, block, after = section
        # `before` is empty only when the section begins the remaining text,
        # which can only happen on the first pass. Skipping it avoids joining
        # an empty piece and so prepending a spurious newline.
        if before or pieces:
            pieces.append(before)
        pieces.append(sort_func(block, filepath, filename, file_extension))
        if not after:
            break
        unprocessed = after

    processed = "\n".join(pieces)
    if text != processed:
        print(sort_name, filepath)
    return processed
def find_forward_class_statement_section(text):
    """Locate the first run of ``@class`` forward declarations in *text*.

    A line qualifies when, after stripping surrounding whitespace, it starts
    with ``"@class "``. Returns the ``(before, section, after)`` triple from
    ``find_matching_section``, or None when no line qualifies.
    """
    return find_matching_section(
        text, lambda line: line.strip().startswith("@class ")
    )
def find_forward_protocol_statement_section(text):
    """Locate the first run of ``@protocol Foo;`` forward declarations in *text*.

    A line qualifies when, after stripping, it starts with ``"@protocol "`` and
    ends with ``";"``. A ``@protocol`` line without a trailing semicolon (e.g.
    the opening of a protocol definition) therefore does not qualify. Returns
    the ``find_matching_section`` triple, or None when no line qualifies.
    """

    def is_protocol_forward_decl(line):
        stripped = line.strip()
        return stripped.startswith("@protocol ") and stripped.endswith(";")

    return find_matching_section(text, is_protocol_forward_decl)
def sort_forward_class_statements(filepath, filename, file_extension, text):
    """Sort and de-duplicate each run of ``@class`` forward declarations.

    Only files with a ``.h``, ``.m`` or ``.mm`` extension are touched; any
    other *text* is returned unchanged.
    """
    objc_extensions = (".h", ".m", ".mm")
    if file_extension in objc_extensions:
        return sort_matching_blocks(
            "sort_class_statements",
            filepath,
            filename,
            file_extension,
            text,
            find_forward_class_statement_section,
            sort_forward_decl_statement_block,
        )
    return text
def sort_forward_protocol_statements(filepath, filename, file_extension, text):
    """Sort and de-duplicate each run of ``@protocol Foo;`` forward declarations.

    Only files with a ``.h``, ``.m`` or ``.mm`` extension are touched; any
    other *text* is returned unchanged.
    """
    objc_extensions = (".h", ".m", ".mm")
    if file_extension in objc_extensions:
        return sort_matching_blocks(
            "sort_forward_protocol_statements",
            filepath,
            filename,
            file_extension,
            text,
            find_forward_protocol_statement_section,
            sort_forward_decl_statement_block,
        )
    return text
def get_ext(file: str) -> str:
    """Return the extension of *file* including its leading dot, or ``""``.

    Uses ``os.path.splitext``, so a leading dot alone (``.bashrc``) is not
    treated as an extension.
    """
    _root, extension = os.path.splitext(file)
    return extension
def process(filepath):
    """Sort forward declarations and normalize the copyright header of a file.

    The steps are:
    1. Sort forward ``@class`` / ``@protocol`` declaration runs (the sort
       helpers only act on ``.h``/``.m``/``.mm`` files).
    2. Set aside a leading shebang or ``// swift-tools-version:`` line.
    3. Drop the leading ``//`` comment block (the existing header).
    4. Prepend the standard copyright header for the current year.

    The file is rewritten, and "Updating: <path>" printed, only if its content
    changed.

    Raises:
        Exception: if called with a dotfile.
    """
    short_filepath = os.path.relpath(filepath, git_repo_path)
    filename = os.path.basename(filepath)
    if filename.startswith("."):
        raise Exception("shouldn't call process with dotfile")
    file_ext = get_ext(filename)

    # Read explicitly as UTF-8 rather than relying on the locale's default
    # encoding, which could mis-decode or fail on non-ASCII sources.
    with open(filepath, "rt", encoding="utf-8") as f:
        text = f.read()
    original_text = text

    text = sort_forward_class_statements(filepath, filename, file_ext, text)
    text = sort_forward_protocol_statements(filepath, filename, file_ext, text)

    lines = text.split("\n")
    shebang = ""
    # These first lines must stay first, so keep them above the new header.
    if lines[0].startswith(("#!", "// swift-tools-version:")):
        shebang = lines[0] + "\n"
        lines = lines[1:]

    # Skip the existing `//` header block. Stop at `///` lines: those are doc
    # comments on the first declaration, not part of a header, and must not be
    # deleted from files that have no header yet.
    body_start = 0
    while (
        body_start < len(lines)
        and lines[body_start].startswith("//")
        and not lines[body_start].startswith("///")
    ):
        body_start += 1
    body = "\n".join(lines[body_start:]).strip()

    year = datetime.datetime.now().year
    header = (
        f"//\n// Copyright (c) {year} Open Whisper Systems. All rights reserved.\n//\n"
    )
    text = shebang + header + body + "\n"

    if original_text == text:
        return
    print("Updating:", short_filepath)
    with open(filepath, "wt", encoding="utf-8") as f:
        f.write(text)
def get_file_paths_in(path: str) -> Iterable[str]:
    """Yield the absolute path of every non-directory entry under *path*.

    The tree is walked recursively with ``os.walk`` using its defaults.
    """
    for dirpath, _subdirs, names in os.walk(path):
        yield from (os.path.abspath(os.path.join(dirpath, name)) for name in names)
def get_file_paths_for_commands(
    commands: Iterable[list[str]], repo_root=None
) -> Iterable[str]:
    """Yield absolute paths of existing files listed by each command.

    Each command (an argv list, e.g. ``git diff --name-only ...``) is run and
    its stdout is read as one path per line. Paths are resolved against
    *repo_root*, which defaults to ``git_repo_path``. Blank lines and paths
    that do not exist are skipped. A path printed by several commands is
    yielded once per command.

    Raises:
        subprocess.CalledProcessError: if a command exits with non-zero status.
    """
    root = git_repo_path if repo_root is None else repo_root
    for command in commands:
        output = subprocess.check_output(command, text=True)
        for line in output.split("\n"):
            # A blank entry (at least the one after the final newline) would
            # resolve to the root directory itself, which always exists.
            if not line:
                continue
            file_path = os.path.abspath(os.path.join(root, line))
            if os.path.exists(file_path):
                yield file_path
def should_process_file(file_path: str) -> bool:
    """Decide whether the precommit checks apply to *file_path*.

    The file qualifies when its extension is in ``EXTENSIONS_TO_CHECK`` and no
    component of its path meets any of these conditions:
    - it is hidden (starts with ``.``);
    - it ends with ``.framework``;
    - it is one of the excluded directory names ``Pods``, ``ThirdParty`` or
      ``Carthage``.
    """
    if get_ext(file_path) not in EXTENSIONS_TO_CHECK:
        return False
    excluded_dirs = ("Pods", "ThirdParty", "Carthage")
    return not any(
        part.startswith(".") or part.endswith(".framework") or part in excluded_dirs
        for part in Path(file_path).parts
    )
def lint_swift_files(file_paths: set[str]) -> bool:
    """Auto-fix, then strictly lint, the Swift files among *file_paths*.

    The files are handed to SwiftLint through the ``SCRIPT_INPUT_FILE_COUNT``
    and ``SCRIPT_INPUT_FILE_<n>`` environment variables
    (``--use-script-input-files``). SwiftLint runs twice:
    - a ``--fix`` pass that applies autocorrections;
    - a ``--strict`` pass that reports the remaining violations, treating
      warnings as errors.
    The stdout of each pass is printed.

    Returns:
        True if there were no Swift files or the strict pass succeeded.
        False if the strict pass exited with non-zero status, i.e. violations
        remain.
    """
    swift_file_paths = [f for f in file_paths if get_ext(f) == ".swift"]
    if not swift_file_paths:
        return True

    env = os.environ.copy()
    env["SCRIPT_INPUT_FILE_COUNT"] = str(len(swift_file_paths))
    for i, file_path in enumerate(swift_file_paths):
        env[f"SCRIPT_INPUT_FILE_{i}"] = file_path

    def run_swiftlint(flag):
        """Run ``swiftlint lint <flag>``; return (succeeded, captured stdout)."""
        try:
            output = subprocess.check_output(
                ["swiftlint", "lint", flag, "--use-script-input-files"],
                env=env,
                text=True,
            )
        except subprocess.CalledProcessError as error:
            return False, error.output
        return True, output

    # The fix pass's exit status is not the verdict; only the strict pass is.
    _, fix_output = run_swiftlint("--fix")
    print(fix_output)

    strict_ok, strict_output = run_swiftlint("--strict")
    print(strict_output)
    return strict_ok
def check_diff_for_keywords():
    """Warn about assertion/failure calls that appear in the staged diff.

    Pipes ``git diff --staged`` through ``grep -E`` (3 lines of context,
    colored output) looking for calls such as ``OWSFail(`` or ``fatalError(``.
    Any hits are printed with a warning banner. This only warns: it returns
    None whether or not keywords were found.

    Raises:
        subprocess.CalledProcessError: if grep exits with status > 1, which
            signals an error such as a bad expression.
    """
    # Raw strings: "\(" is an invalid escape sequence in a normal string
    # literal (SyntaxWarning on Python 3.12+). The regex value is unchanged.
    objc_keywords = [
        r"OWSAbstractMethod\(",
        r"OWSAssert\(",
        r"OWSCAssert\(",
        r"OWSFail\(",
        r"OWSCFail\(",
        r"ows_add_overflow\(",
        r"ows_sub_overflow\(",
    ]
    swift_keywords = [
        r"owsFail\(",
        r"precondition\(",
        r"fatalError\(",
        r"dispatchPrecondition\(",
        r"preconditionFailure\(",
        r"notImplemented\(",
    ]
    matching_expression = "|".join(objc_keywords + swift_keywords)
    command_line = (
        'git diff --staged | grep --color=always -C 3 -E "%s"' % matching_expression
    )
    try:
        output = subprocess.check_output(command_line, shell=True, text=True)
    except subprocess.CalledProcessError as e:
        # grep exit status: 0 = lines selected, 1 = no lines selected,
        # >1 = an error occurred.
        if e.returncode == 1:
            # No keywords in the diff.
            return
        raise
    if output:
        print("⚠️ keywords detected in diff:")
        print(output)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Precommit script.")
    parser.add_argument(
        "--all", action="store_true", help="process all files in the git repository"
    )
    parser.add_argument("--path", help="used to specify a path to process.")
    parser.add_argument(
        "--ref", help="process all files that have changed since the given ref"
    )
    args = parser.parse_args()

    all_file_paths: Iterable[str] = []
    # Commit passed to `git clang-format --commit`.
    clang_format_commit = "HEAD"
    if args.all:
        all_file_paths = get_file_paths_in(git_repo_path)
    elif args.path:
        all_file_paths = get_file_paths_in(args.path)
    elif args.ref:
        all_file_paths = get_file_paths_for_commands(
            [["git", "diff", "--name-only", "--diff-filter=ACMR", args.ref, "HEAD"]]
        )
        clang_format_commit = args.ref
    else:
        # Default: both staged and unstaged changes.
        all_file_paths = get_file_paths_for_commands(
            [
                ["git", "diff", "--cached", "--name-only", "--diff-filter=ACMR"],
                ["git", "diff", "--name-only", "--diff-filter=ACMR"],
            ]
        )
    file_paths = set(filter(should_process_file, all_file_paths))

    # Only an explicit False result is treated as a lint failure.
    lint_ok = lint_swift_files(file_paths)

    for file_path in file_paths:
        process(file_path)

    print("Sorting Xcode project...")
    # Both paths in this command are repository-relative, so run it from the
    # repository root even when this script is invoked from a subdirectory.
    # Like subprocess.getoutput: stderr merged into stdout, one trailing
    # newline removed, and a non-zero exit status is not raised.
    sort_result = subprocess.run(
        "Scripts/sort-Xcode-project-file Signal.xcodeproj",
        shell=True,
        cwd=git_repo_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    print(sort_result.stdout.removesuffix("\n"))

    print("git clang-format...")
    # we don't want to format .proto files, so we specify every other supported extension
    print(
        subprocess.getoutput(
            'git clang-format --extensions "c,h,m,mm,cc,cp,cpp,c++,cxx,hh,hxx,cu,java,js,ts,cs" --commit %s'
            % clang_format_commit
        )
    )

    check_diff_for_keywords()

    # Fail only after every step has run, so the fixes above are still applied;
    # a non-zero exit makes CI fail on any (strict) SwiftLint violation.
    if lint_ok is False:
        sys.exit("SwiftLint reported violations (warnings are treated as errors).")