#!/usr/bin/env python3 import os import subprocess import argparse import re import json import sds_common from sds_common import fail import random # TODO: We should probably generate a class that knows how to set up # the database. It would: # # * Create all tables (or apply database schema). # * Register renamed classes. # [NSKeyedUnarchiver setClass:[OWSUserProfile class] forClassName:[OWSUserProfile collection]]; # [NSKeyedUnarchiver setClass:[OWSDatabaseMigration class] forClassName:[OWSDatabaseMigration collection]]; # We consider any subclass of TSYapDatabaseObject to be a "serializable model". # # We treat direct subclasses of TSYapDatabaseObject as "roots" of the model class hierarchy. # Only root models do deserialization. OLD_BASE_MODEL_CLASS_NAME = "TSYapDatabaseObject" NEW_BASE_MODEL_CLASS_NAME = "BaseModel" CODE_GEN_SNIPPET_MARKER_OBJC = "// --- CODE GENERATION MARKER" # GRDB seems to encode non-primitive using JSON. # GRDB chokes when decodes this JSON, due to it being a JSON "fragment". # Either this is a bug in GRDB or we're using GRDB incorrectly. # Until we resolve this issue, we need to encode/decode # non-primitives ourselves. USE_CODABLE_FOR_PRIMITIVES = False USE_CODABLE_FOR_NONPRIMITIVES = False def update_generated_snippet(file_path, marker, snippet): # file_path = sds_common.sds_from_relative_path(relative_path) if not os.path.exists(file_path): fail("Missing file:", file_path) with open(file_path, "rt") as f: text = f.read() start_index = text.find(marker) end_index = text.rfind(marker) if start_index < 0 or end_index < 0 or start_index >= end_index: fail(f"Could not find markers ('{marker}'): {file_path}") text = ( text[:start_index].strip() + "\n\n" + marker + "\n\n" + snippet + "\n\n" + marker + "\n\n" + text[end_index + len(marker) :].lstrip() ) sds_common.write_text_file_if_changed(file_path, text) def update_objc_snippet(file_path, snippet): snippet = sds_common.clean_up_generated_objc(snippet).strip() if len(snippet) < 1: return snippet = ( "// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`." % (sds_common.pretty_module_path(__file__),) + "\n\n" + snippet ) update_generated_snippet(file_path, CODE_GEN_SNIPPET_MARKER_OBJC, snippet) # ---- global_class_map = {} global_subclass_map = {} global_args = None # ---- def to_swift_identifier_name(identifier_name): return identifier_name[0].lower() + identifier_name[1:] class ParsedClass: def __init__(self, json_dict): self.name = json_dict.get("name") self.super_class_name = json_dict.get("super_class_name") self.filepath = sds_common.sds_from_relative_path(json_dict.get("filepath")) self.finalize_method_name = json_dict.get("finalize_method_name") self.property_map = {} for property_dict in json_dict.get("properties"): property = ParsedProperty(property_dict) property.class_name = self.name # TODO: We should handle all properties? if property.should_ignore_property(): continue self.property_map[property.name] = property def properties(self): result = [] for name in sorted(self.property_map.keys()): result.append(self.property_map[name]) return result def database_subclass_properties(self): # More than one subclass of a SDS model may declare properties # with the same name. This is fine, so long as they have # the same type. all_property_map = {} subclass_property_map = {} root_property_names = set() for property in self.properties(): all_property_map[property.name] = property root_property_names.add(property.name) for subclass in all_descendents_of_class(self): if should_ignore_class(subclass): continue for property in subclass.properties(): duplicate_property = all_property_map.get(property.name) if duplicate_property is not None: if ( property.swift_type_safe() != duplicate_property.swift_type_safe() ): print( "property:", property.class_name, property.name, property.swift_type_safe(), property.is_optional, ) print( "duplicate_property:", duplicate_property.class_name, duplicate_property.name, duplicate_property.swift_type_safe(), duplicate_property.is_optional, ) fail("Duplicate property doesn't match:", property.name) elif property.is_optional != duplicate_property.is_optional: if property.name in root_property_names: print( "property:", property.class_name, property.name, property.swift_type_safe(), property.is_optional, ) print( "duplicate_property:", duplicate_property.class_name, duplicate_property.name, duplicate_property.swift_type_safe(), duplicate_property.is_optional, ) fail("Duplicate property doesn't match:", property.name) # If one subclass property is optional and the other isn't, we should # treat both as optional for the purposes of the database schema. if not property.is_optional: continue else: continue all_property_map[property.name] = property subclass_property_map[property.name] = property result = [] for name in sorted(subclass_property_map.keys()): result.append(subclass_property_map[name]) return result def record_id_source(self): for property in self.properties(): if property.name == "sortId": return property.name return None def is_sds_model(self): if self.super_class_name is None: return False if not self.super_class_name in global_class_map: return False if self.super_class_name in ( OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME, ): return True super_class = global_class_map[self.super_class_name] return super_class.is_sds_model() def has_sds_superclass(self): return ( self.super_class_name and self.super_class_name in global_class_map and self.super_class_name != OLD_BASE_MODEL_CLASS_NAME and self.super_class_name != NEW_BASE_MODEL_CLASS_NAME ) def table_superclass(self): if self.super_class_name is None: return self if not self.super_class_name in global_class_map: return self if self.super_class_name == OLD_BASE_MODEL_CLASS_NAME: return self if self.super_class_name == NEW_BASE_MODEL_CLASS_NAME: return self super_class = global_class_map[self.super_class_name] return super_class.table_superclass() def all_superclass_names(self): result = [self.name] if self.super_class_name is not None: if self.super_class_name in global_class_map: super_class = global_class_map[self.super_class_name] result += super_class.all_superclass_names() return result def has_any_superclass_with_name(self, name): return name in self.all_superclass_names() def should_generate_extensions(self): if self.name in ( OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME, ): return False if should_ignore_class(self): return False if not self.is_sds_model(): # Only write serialization extensions for SDS models. return False # The migration should not be persisted in the data store. if self.name in ( "OWSDatabaseMigration", "YDBDatabaseMigration", "OWSResaveCollectionDBMigration", ): return False if self.super_class_name in ( "OWSDatabaseMigration", "YDBDatabaseMigration", "OWSResaveCollectionDBMigration", ): return False return True def record_name(self): return remove_prefix_from_class_name(self.name) + "Record" def sorted_record_properties(self): record_name = self.record_name() # If a property has a custom column source, we don't redundantly create a column for that column base_properties = [ property for property in self.properties() if not property.has_aliased_column_name() ] # If a property has a custom column source, we don't redundantly create a column for that column subclass_properties = [ property for property in self.database_subclass_properties() if not property.has_aliased_column_name() ] # We need to maintain a stable ordering of record properties # across migrations, e.g. adding new columns to the tables. # # First, we build a list of "model" properties. This is the # the superset of properties in the model base class and all # of its subclasses. # # NOTE: We punch two values onto these properties: # force_optional and property_order. record_properties = [] for property in base_properties: # Treat all enum properties as forced-optional, so that during # deserialization we can survive unexpected raw values. # # Except the special-cased ones. force_optional = property.type_info().is_enum force_optional = force_optional and property.name not in ["mentionNotificationMode", "storyViewMode"] property.force_optional = force_optional record_properties.append(property) for property in subclass_properties: # We must "force" subclass properties to be optional # since they don't apply to the base model and other # subclasses. property.force_optional = True record_properties.append(property) for property in record_properties: # Try to load the known "order" for each property. # # "Orders" are indices used to ensure a stable ordering. # We find the "orders" of all properties that already have # one. # # This will initially be nil for new properties # which have not yet been assigned an order. property.property_order = property_order_for_property(property, record_name) all_property_orders = [ property.property_order for property in record_properties if property.property_order ] # We determine the "next" order we would assign to any # new property without an order. next_property_order = 1 + ( max(all_property_orders) if len(all_property_orders) > 0 else 0 ) # Pre-sort model properties by name, so that if we add more # than one at a time they are nicely (and stable-y) sorted # in an attractive way. record_properties.sort(key=lambda value: value.name) # Now iterate over all model properties and assign an order # to any new properties without one. for property in record_properties: if property.property_order is None: property.property_order = next_property_order # We "set" the order in the mapping which is persisted # as JSON to ensure continuity. set_property_order_for_property( property, record_name, next_property_order ) next_property_order = next_property_order + 1 # Now sort the model properties, applying the ordering. record_properties.sort(key=lambda value: value.property_order) return record_properties class TypeInfo: def __init__( self, swift_type, objc_type, should_use_blob=False, is_codable=False, is_enum=False, field_override_column_type=None, field_override_record_swift_type=None, ): self._swift_type = swift_type self._objc_type = objc_type self.should_use_blob = should_use_blob self.is_codable = is_codable self.is_enum = is_enum self.field_override_column_type = field_override_column_type self.field_override_record_swift_type = field_override_record_swift_type def swift_type(self): return str(self._swift_type) def objc_type(self): return str(self._objc_type) # This defines the mapping of Swift types to database column types. # We'll be iterating on this mapping. # Note that we currently store all sub-models and collections (e.g. [String]) as a blob. # # TODO: def database_column_type(self, value_name): if self.field_override_column_type is not None: return self.field_override_column_type elif self.should_use_blob or self.is_codable: return ".blob" elif self.is_enum: return ".int" elif self._swift_type == "String": return ".unicodeString" elif self._objc_type == "NSDate *": # Persist dates as NSTimeInterval timeIntervalSince1970. return ".double" elif self._swift_type == "Date": # Persist dates as NSTimeInterval timeIntervalSince1970. fail( 'We should not use `Date` as a "swift type" since all NSDates are serialized as doubles.', self._swift_type, ) elif self._swift_type == "Data": return ".blob" elif self._swift_type in ("Boolouble", "Bool"): return ".int" elif self._swift_type in ("Double", "Float"): return ".double" elif self.is_numeric(): return ".int64" else: fail("Unknown type(1):", self._swift_type) def is_numeric(self): # TODO: We need to revisit how we serialize numeric types. return self._swift_type in ( # 'signed char', "Bool", "UInt64", "UInt", "Int64", "Int", "Int32", "UInt32", "Double", "Float", ) def should_cast_to_swift(self): if self._swift_type in ( "Bool", "Int64", "UInt64", ): return False return self.is_numeric() def deserialize_record_invocation( self, property, value_name, is_optional, did_force_optional ): value_expr = "record.%s" % (property.column_name(),) deserialization_optional = None deserialization_not_optional = None deserialization_conversion = "" if self._swift_type == "String": deserialization_not_optional = "required" elif self._objc_type == "NSDate *": pass elif self._swift_type == "Date": fail("Unknown type(0):", self._swift_type) elif self.is_codable: deserialization_not_optional = "required" elif self._swift_type == "Data": deserialization_optional = "optionalData" deserialization_not_optional = "required" elif self.is_numeric(): deserialization_optional = "optionalNumericAsNSNumber" deserialization_not_optional = "required" deserialization_conversion = ", conversion: { NSNumber(value: $0) }" initializer_param_type = self.swift_type() if is_optional: initializer_param_type = initializer_param_type + "?" # Special-case the unpacking of the auto-incremented # primary key. if value_expr == "record.id": value_expr = "%s(recordId)" % (initializer_param_type,) elif is_optional: if deserialization_optional is not None: value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % ( deserialization_optional, value_expr, value_name, deserialization_conversion, ) elif did_force_optional: if deserialization_not_optional is not None: value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % ( deserialization_not_optional, value_expr, value_name, ) else: # Do nothing; we don't need to unpack this non-optional. pass if value_name == "mentionNotificationMode": value_statement = ( "let %s: %s = TSThreadMentionNotificationMode(rawValue: %s) ?? .default" % ( value_name, "TSThreadMentionNotificationMode", value_expr, ) ) elif value_name == "storyViewMode": value_statement = ( "let %s: %s = TSThreadStoryViewMode(rawValue: %s) ?? .default" % ( value_name, "TSThreadStoryViewMode", value_expr, ) ) elif self.is_codable: value_statement = "let %s: %s = %s" % ( value_name, initializer_param_type, value_expr, ) elif self.should_use_blob: blob_name = "%sSerialized" % (str(value_name),) if is_optional or did_force_optional: serialized_statement = "let %s: Data? = %s" % ( blob_name, value_expr, ) else: serialized_statement = "let %s: Data = %s" % ( blob_name, value_expr, ) if is_optional: value_statement = ( 'let %s: %s? = try SDSDeserialization.optionalUnarchive(%s, name: "%s")' % ( value_name, self._swift_type, blob_name, value_name, ) ) else: value_statement = ( 'let %s: %s = try SDSDeserialization.unarchive(%s, name: "%s")' % ( value_name, self._swift_type, blob_name, value_name, ) ) return [ serialized_statement, value_statement, ] elif self.is_enum and did_force_optional and not is_optional: return [ "guard let %s: %s = %s else {" % ( value_name, initializer_param_type, value_expr, ), " throw SDSError.missingRequiredField()", "}", ] elif is_optional and self._objc_type == "NSNumber *": return [ "let %s: %s = %s" % ( value_name, "NSNumber?", value_expr, ), # 'let %sRaw = %s' % ( value_name, value_expr, ), # 'var %s : NSNumber?' % ( value_name, ), # 'if let value = %sRaw {' % ( value_name, ), # ' %s = NSNumber(value: value)' % ( value_name, ), # '}', ] elif self._objc_type == "NSDate *": # Persist dates as NSTimeInterval timeIntervalSince1970. interval_name = "%sInterval" % (str(value_name),) if did_force_optional: serialized_statements = [ "guard let %s: Double = %s else {" % ( interval_name, value_expr, ), " throw SDSError.missingRequiredField()", "}", ] elif is_optional: serialized_statements = [ "let %s: Double? = %s" % ( interval_name, value_expr, ), ] else: serialized_statements = [ "let %s: Double = %s" % ( interval_name, value_expr, ), ] if is_optional: value_statement = ( 'let %s: Date? = SDSDeserialization.optionalDoubleAsDate(%s, name: "%s")' % ( value_name, interval_name, value_name, ) ) else: value_statement = ( 'let %s: Date = SDSDeserialization.requiredDoubleAsDate(%s, name: "%s")' % ( value_name, interval_name, value_name, ) ) return serialized_statements + [ value_statement, ] else: value_statement = "let %s: %s = %s" % ( value_name, initializer_param_type, value_expr, ) return [ value_statement, ] def serialize_record_invocation( self, property, value_name, is_optional, did_force_optional ): value_expr = value_name if property.field_override_serialize_record_invocation() is not None: return property.field_override_serialize_record_invocation() % (value_expr,) elif self.is_codable: pass elif self.should_use_blob: # blob_name = '%sSerialized' % ( str(value_name), ) if is_optional or did_force_optional: return "optionalArchive(%s)" % (value_expr,) else: return "requiredArchive(%s)" % (value_expr,) elif self._objc_type == "NSDate *": if is_optional or did_force_optional: return "archiveOptionalDate(%s)" % (value_expr,) else: return "archiveDate(%s)" % (value_expr,) elif self._objc_type == "NSNumber *": # elif self.is_numeric(): conversion_map = { "Int8": "int8Value", "UInt8": "uint8Value", "Int16": "int16Value", "UInt16": "uint16Value", "Int32": "int32Value", "UInt32": "uint32Value", "Int64": "int64Value", "UInt64": "uint64Value", "Float": "floatValue", "Double": "doubleValue", "Bool": "boolValue", "Int": "intValue", "UInt": "uintValue", } conversion_method = conversion_map[self.swift_type()] if conversion_method is None: fail("Could not convert:", self.swift_type()) serialization_conversion = "{ $0.%s }" % (conversion_method,) if is_optional or did_force_optional: return "archiveOptionalNSNumber(%s, conversion: %s)" % ( value_expr, serialization_conversion, ) else: return "archiveNSNumber(%s, conversion: %s)" % ( value_expr, serialization_conversion, ) return value_expr def record_field_type(self, value_name): # Special case this oddball type. if self.field_override_record_swift_type is not None: return self.field_override_record_swift_type elif self.is_codable: pass elif self.should_use_blob: return "Data" return self.swift_type() class ParsedProperty: def __init__(self, json_dict): self.name = json_dict.get("name") self.is_optional = json_dict.get("is_optional") self.objc_type = json_dict.get("objc_type") self.class_name = json_dict.get("class_name") self.swift_type = None def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True): if objc_type is None: fail("Missing type") elif objc_type == "NSString *": return "String" elif objc_type == "NSDate *": # Persist dates as NSTimeInterval timeIntervalSince1970. return "Double" elif objc_type == "NSData *": return "Data" elif objc_type == "BOOL": return "Bool" elif objc_type == "NSInteger": return "Int" elif objc_type == "NSUInteger": return "UInt" elif objc_type == "int32_t": return "Int32" elif objc_type == "uint32_t": return "UInt32" elif objc_type == "int64_t": return "Int64" elif objc_type == "long long": return "Int64" elif objc_type == "unsigned long long": return "UInt64" elif objc_type == "uint64_t": return "UInt64" elif objc_type == "unsigned long": return "UInt64" elif objc_type == "unsigned int": return "UInt32" elif objc_type == "double": return "Double" elif objc_type == "float": return "Float" elif objc_type == "CGFloat": return "Double" elif objc_type == "NSNumber *": if unpack_nsnumber: return swift_type_for_nsnumber(self) else: return "NSNumber" else: return None # NOTE: This method recurses to unpack types like: NSArray *> * def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True): if objc_type == "id": return "AnyObject" elif not objc_type.endswith(" *"): return None swift_primitive = self.try_to_convert_objc_primitive_to_swift( objc_type, unpack_nsnumber=unpack_nsnumber ) if swift_primitive is not None: return swift_primitive array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type) if array_match is not None: split = array_match.group(2) return ( "[" + self.convert_objc_class_to_swift(split, unpack_nsnumber=False) + "]" ) dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type) if dict_match is not None: split1 = dict_match.group(2).strip() split2 = dict_match.group(3).strip() return ( "[" + self.convert_objc_class_to_swift(split1, unpack_nsnumber=False) + ": " + self.convert_objc_class_to_swift(split2, unpack_nsnumber=False) + "]" ) ordered_set_match = re.search(r"^NSOrderedSet<(.+)> \*$", objc_type) if ordered_set_match is not None: # swift has no primitive for ordered set, so we lose the element type return "NSOrderedSet" swift_type = objc_type[: -len(" *")] if "<" in swift_type or "{" in swift_type or "*" in swift_type: fail("Unexpected type:", objc_type) return swift_type def try_to_convert_objc_type_to_type_info(self): objc_type = self.objc_type if objc_type is None: fail("Missing type") elif self.field_override_swift_type(): return TypeInfo( self.field_override_swift_type(), objc_type, should_use_blob=self.field_override_should_use_blob(), is_enum=self.field_override_is_enum(), field_override_column_type=self.field_override_column_type(), field_override_record_swift_type=self.field_override_record_swift_type(), ) elif objc_type in enum_type_map: enum_type = objc_type return TypeInfo(enum_type, objc_type, is_enum=True) elif objc_type.startswith("enum "): enum_type = objc_type[len("enum ") :] return TypeInfo(enum_type, objc_type, is_enum=True) swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type) if swift_primitive is not None: return TypeInfo(swift_primitive, objc_type) if objc_type in ( "struct CGSize", "struct CGRect", "struct CGPoint", ): objc_type = objc_type[len("struct ") :] swift_type = objc_type return TypeInfo( swift_type, objc_type, should_use_blob=True, is_codable=USE_CODABLE_FOR_PRIMITIVES, ) swift_type = self.convert_objc_class_to_swift(self.objc_type) if swift_type is not None: if self.is_objc_type_codable(objc_type): return TypeInfo( swift_type, objc_type, should_use_blob=True, is_codable=False ) return TypeInfo( swift_type, objc_type, should_use_blob=True, is_codable=False ) fail("Unknown type(3):", self.class_name, self.objc_type, self.name) # NOTE: This method recurses to unpack types like: NSArray *> * def is_objc_type_codable(self, objc_type): if not USE_CODABLE_FOR_PRIMITIVES: return False if objc_type in ("NSString *",): return True elif objc_type in ( "struct CGSize", "struct CGRect", "struct CGPoint", ): return True elif self.field_override_is_objc_codable() is not None: return self.field_override_is_objc_codable() elif objc_type in enum_type_map: return True elif objc_type.startswith("enum "): return True if not USE_CODABLE_FOR_NONPRIMITIVES: return False array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type) if array_match is not None: split = array_match.group(2) return self.is_objc_type_codable(split) dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type) if dict_match is not None: split1 = dict_match.group(2).strip() split2 = dict_match.group(3).strip() return self.is_objc_type_codable(split1) and self.is_objc_type_codable( split2 ) return False def field_override_swift_type(self): return self._field_override("swift_type") def field_override_is_objc_codable(self): return self._field_override("is_objc_codable") def field_override_is_enum(self): return self._field_override("is_enum") def field_override_column_type(self): return self._field_override("column_type") def field_override_record_swift_type(self): return self._field_override("record_swift_type") def field_override_serialize_record_invocation(self): return self._field_override("serialize_record_invocation") def field_override_should_use_blob(self): return self._field_override("should_use_blob") def field_override_objc_initializer_type(self): return self._field_override("objc_initializer_type") def _field_override(self, override_field): manually_typed_fields = configuration_json.get("manually_typed_fields") if manually_typed_fields is None: fail("Configuration JSON is missing manually_typed_fields") key = self.class_name + "." + self.name if key in manually_typed_fields: return manually_typed_fields[key][override_field] else: return None def type_info(self): if self.swift_type is not None: should_use_blob = ( self.swift_type.startswith("[") or self.swift_type.startswith("{") or is_swift_class_name(self.swift_type) ) return TypeInfo( self.swift_type, objc_type, should_use_blob=should_use_blob, is_codable=USE_CODABLE_FOR_PRIMITIVES, field_override_column_type=self.field_override_column_type, ) return self.try_to_convert_objc_type_to_type_info() def swift_type_safe(self): return self.type_info().swift_type() def objc_type_safe(self): if self.field_override_objc_initializer_type() is not None: return self.field_override_objc_initializer_type() result = self.type_info().objc_type() if result.startswith("enum "): result = result[len("enum ") :] return result # if self.objc_type is None: # fail("Don't know Obj-C type for:", self.name) # return self.objc_type def database_column_type(self): return self.type_info().database_column_type(self.name) def should_ignore_property(self): return should_ignore_property(self) def has_aliased_column_name(self): return aliased_column_name_for_property(self) is not None def deserialize_record_invocation(self, value_name, did_force_optional): return self.type_info().deserialize_record_invocation( self, value_name, self.is_optional, did_force_optional ) def deep_copy_record_invocation(self, value_name, did_force_optional): swift_type = self.swift_type_safe() objc_type = self.objc_type_safe() is_optional = self.is_optional model_accessor = accessor_name_for_property(self) initializer_param_type = swift_type if is_optional: initializer_param_type = initializer_param_type + "?" simple_type_map = { "NSString *": "String", "NSNumber *": "NSNumber", "NSDate *": "Date", "NSData *": "Data", "CGSize": "CGSize", "CGRect": "CGRect", "CGPoint": "CGPoint", } if objc_type in simple_type_map: initializer_param_type = simple_type_map[objc_type] if is_optional: initializer_param_type += "?" return [ "let %s: %s = modelToCopy.%s" % ( value_name, initializer_param_type, model_accessor, ), ] can_shallow_copy = False if self.type_info().is_numeric(): can_shallow_copy = True elif self.is_enum(): can_shallow_copy = True if can_shallow_copy: return [ "let %s: %s = modelToCopy.%s" % ( value_name, initializer_param_type, model_accessor, ), ] initializer_param_type = initializer_param_type.replace("AnyObject", "Any") if is_optional: return [ "let %s: %s" % ( value_name, initializer_param_type, ), "if let %sForCopy = modelToCopy.%s {" % ( value_name, model_accessor, ), " %s = try DeepCopies.deepCopy(%sForCopy)" % ( value_name, value_name, ), "} else {", " %s = nil" % (value_name,), "}", ] else: return [ "let %s: %s = try DeepCopies.deepCopy(modelToCopy.%s)" % ( value_name, initializer_param_type, model_accessor, ), ] fail( "I don't know how to deep copy this type: %s / %s" % (objc_type, swift_type) ) def possible_class_type_for_property(self): swift_type = self.swift_type_safe() if swift_type in global_class_map: return global_class_map[swift_type] objc_type = self.objc_type_safe() if objc_type.endswith(" *"): objc_type = objc_type[:-2] if objc_type in global_class_map: return global_class_map[objc_type] return None def serialize_record_invocation(self, value_name, did_force_optional): return self.type_info().serialize_record_invocation( self, value_name, self.is_optional, did_force_optional ) def record_field_type(self): return self.type_info().record_field_type(self.name) def is_enum(self): return self.type_info().is_enum def swift_identifier(self): return to_swift_identifier_name(self.name) def column_name(self): aliased_column_name = aliased_column_name_for_property(self) if aliased_column_name is not None: return aliased_column_name custom_column_name = custom_column_name_for_property(self) if custom_column_name is not None: return custom_column_name else: return self.swift_identifier() def ows_getoutput(cmd): proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) stdout, stderr = proc.communicate() return proc.returncode, stdout, stderr # ---- Parsing def properties_and_inherited_properties(clazz): result = [] if clazz.super_class_name in global_class_map: super_class = global_class_map[clazz.super_class_name] result.extend(properties_and_inherited_properties(super_class)) result.extend(clazz.properties()) return result def generate_swift_extensions_for_model(clazz): if not clazz.should_generate_extensions(): return has_sds_superclass = clazz.has_sds_superclass() has_remove_methods = clazz.name not in ("TSThread", "TSInteraction") has_grdb_serializer = clazz.name in ("TSInteraction") swift_filename = os.path.basename(clazz.filepath) swift_filename = swift_filename[: swift_filename.find(".")] + "+SDS.swift" swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename) record_type = get_record_type(clazz) # TODO: We'll need to import SignalServiceKit for non-SSK models. swift_body = """// // Copyright 2022 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only // import Foundation %simport GRDB // NOTE: This file is generated by %s. // Do not manually edit it, instead run `sds_codegen.sh`. """ % ( "" if has_sds_superclass else "public ", sds_common.pretty_module_path(__file__), ) if not has_sds_superclass: # If a property has a custom column source, we don't redundantly create a column for that column base_properties = [ property for property in clazz.properties() if not property.has_aliased_column_name() ] # If a property has a custom column source, we don't redundantly create a column for that column subclass_properties = [ property for property in clazz.database_subclass_properties() if not property.has_aliased_column_name() ] swift_body += """ // MARK: - Record """ record_name = clazz.record_name() swift_body += """ public struct %s: SDSRecord { public weak var delegate: SDSRecordDelegate? public var tableMetadata: SDSTableMetadata { %sSerializer.table } public static var databaseTableName: String { %sSerializer.table.tableName } public var id: Int64? // This defines all of the columns used in the table // where this model (and any subclasses) are persisted. public let recordType: SDSRecordType? public let uniqueId: String """ % ( record_name, str(clazz.name), str(clazz.name), ) def write_record_property(property, force_optional=False): column_name = property.swift_identifier() record_field_type = property.record_field_type() is_optional = property.is_optional or force_optional optional_split = "?" if is_optional else "" custom_column_name = custom_column_name_for_property(property) if custom_column_name is not None: column_name = custom_column_name return """ public let %s: %s%s """ % ( str(column_name), record_field_type, optional_split, ) record_properties = clazz.sorted_record_properties() # Declare the model properties in the record. if len(record_properties) > 0: swift_body += "\n // Properties \n" for property in record_properties: swift_body += write_record_property( property, force_optional=property.force_optional ) sds_properties = [ ParsedProperty( { "name": "id", "is_optional": False, "objc_type": "NSInteger", "class_name": clazz.name, } ), ParsedProperty( { "name": "recordType", "is_optional": False, "objc_type": "NSUInteger", "class_name": clazz.name, } ), ParsedProperty( { "name": "uniqueId", "is_optional": False, "objc_type": "NSString *", "class_name": clazz.name, } ), ] # We use the pre-sorted collection record_properties so that # we use the correct property order when generating: # # * CodingKeys # * init(row: Row) # * The table/column metadata. persisted_properties = sds_properties + record_properties swift_body += """ public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable { """ for property in persisted_properties: custom_column_name = custom_column_name_for_property(property) was_property_renamed = was_property_renamed_for_property(property) if custom_column_name is not None: if was_property_renamed: swift_body += """ case %s """ % ( custom_column_name, ) else: swift_body += """ case %s = "%s" """ % ( custom_column_name, property.swift_identifier(), ) else: swift_body += """ case %s """ % ( property.swift_identifier(), ) swift_body += """ } """ swift_body += """ public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String { fullyQualified ? "\\(databaseTableName).\\(column.rawValue)" : column.rawValue } public func didInsert(with rowID: Int64, for column: String?) { guard let delegate = delegate else { owsFailDebug("Missing delegate.") return } delegate.updateRowId(rowID) } } """ % ( record_name, ) swift_body += """ // MARK: - Row Initializer public extension %s { static var databaseSelection: [SQLSelectable] { CodingKeys.allCases } init(row: Row) {""" % ( record_name ) for index, property in enumerate(persisted_properties): type_info = property.type_info() property_name = property.column_name() swift_type = type_info.swift_type() did_force_optional = property.name not in ["mentionNotificationMode", "storyViewMode"] did_force_optional = did_force_optional and type_info.is_enum if property_name == "recordType": # recordType is an enum, but its property info here doesn't # reflect that, so special-case it. swift_body += """ %s = row[%s].flatMap { SDSRecordType(rawValue: $0) }""" % ( property_name, index, ) elif did_force_optional: swift_body += """ %s = row[%s].flatMap { %s(rawValue: $0) }""" % ( property_name, index, swift_type, ) else: swift_body += """ %s = row[%s]""" % ( property_name, index, ) swift_body += """ } } """ swift_body += """ // MARK: - StringInterpolation public extension String.StringInterpolation { mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) { appendLiteral(%(record_name)s.columnName(column)) } mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) { appendLiteral(%(record_name)s.columnName(column, fullyQualified: true)) } } """ % { "record_identifier": record_identifier(clazz.name), "record_name": record_name, } # TODO: Rework metadata to not include, for example, columns, column indices. swift_body += """ // MARK: - Deserialization extension %s { // This method defines how to deserialize a model, given a // database row. The recordType column is used to determine // the corresponding model class. class func fromRecord(_ record: %s) throws -> %s { """ % ( str(clazz.name), record_name, str(clazz.name), ) swift_body += """ guard let recordId = record.id else { throw SDSError.missingRequiredField(fieldName: "id") } guard let recordType = record.recordType else { throw SDSError.missingRequiredField(fieldName: "recordType") } switch recordType { """ deserialize_classes = all_descendents_of_class(clazz) + [clazz] deserialize_classes.sort(key=lambda value: value.name) for deserialize_class in deserialize_classes: if should_ignore_class(deserialize_class): continue initializer_params = [] objc_initializer_params = [] objc_super_initializer_args = [] objc_initializer_assigns = [] deserialize_record_type = get_record_type_enum_name(deserialize_class.name) swift_body += """ case .%s: """ % ( str(deserialize_record_type), ) swift_body += """ let uniqueId: String = record.uniqueId """ base_property_names = set() for property in base_properties: base_property_names.add(property.name) deserialize_properties = properties_and_inherited_properties( deserialize_class ) has_local_properties = False for property in deserialize_properties: value_name = "%s" % property.name if property.name not in ("uniqueId",): did_force_optional = property.name not in ["mentionNotificationMode", "storyViewMode"] did_force_optional = did_force_optional and property.name not in base_property_names did_force_optional = did_force_optional and not property.is_optional did_force_optional = did_force_optional or property.type_info().is_enum for statement in property.deserialize_record_invocation( value_name, did_force_optional ): swift_body += " %s\n" % (str(statement),) initializer_params.append( "%s: %s" % ( str(property.name), value_name, ) ) objc_initializer_type = str(property.objc_type_safe()) if objc_initializer_type.startswith("NSMutable"): objc_initializer_type = ( "NS" + objc_initializer_type[len("NSMutable") :] ) if property.is_optional: objc_initializer_type = "nullable " + objc_initializer_type objc_initializer_params.append( "%s:(%s)%s" % ( str(property.name), objc_initializer_type, str(property.name), ) ) is_superclass_property = property.class_name != deserialize_class.name if is_superclass_property: objc_super_initializer_args.append( "%s:%s" % ( str(property.name), str(property.name), ) ) else: has_local_properties = True if str(property.objc_type_safe()).startswith("NSMutableArray"): objc_initializer_assigns.append( "_%s = %s ? [%s mutableCopy] : [NSMutableArray new];" % ( str(property.name), str(property.name), str(property.name), ) ) elif str(property.objc_type_safe()).startswith( "NSMutableDictionary" ): objc_initializer_assigns.append( "_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];" % ( str(property.name), str(property.name), str(property.name), ) ) elif ( deserialize_class.name == "TSIncomingMessage" and property.name in ("authorUUID", "authorPhoneNumber") ): pass else: objc_initializer_assigns.append( "_%s = %s;" % ( str(property.name), str(property.name), ) ) # --- Initializer Snippets h_snippet = "" h_snippet += """ // clang-format off - (instancetype)initWithGrdbId:(int64_t)grdbId uniqueId:(NSString *)uniqueId """ for objc_initializer_param in objc_initializer_params[1:]: alignment = max( 0, len("- (instancetype)initWithUniqueId") - objc_initializer_param.index(":"), ) h_snippet += (" " * alignment) + objc_initializer_param + "\n" h_snippet += ( "NS_DESIGNATED_INITIALIZER NS_SWIFT_NAME(init(grdbId:%s:));\n" % ":".join([str(property.name) for property in deserialize_properties]) ) h_snippet += """ // clang-format on """ m_snippet = "" m_snippet += """ // clang-format off - (instancetype)initWithGrdbId:(int64_t)grdbId uniqueId:(NSString *)uniqueId """ for objc_initializer_param in objc_initializer_params[1:]: alignment = max( 0, len("- (instancetype)initWithUniqueId") - objc_initializer_param.index(":"), ) m_snippet += (" " * alignment) + objc_initializer_param + "\n" if len(objc_super_initializer_args) == 1: suffix = "];" else: suffix = "" m_snippet += """{ self = [super initWithGrdbId:grdbId uniqueId:uniqueId%s """ % ( suffix ) for index, objc_super_initializer_arg in enumerate( objc_super_initializer_args[1:] ): alignment = max( 0, len(" self = [super initWithUniqueId") - objc_super_initializer_arg.index(":"), ) if index == len(objc_super_initializer_args) - 2: suffix = "];" else: suffix = "" m_snippet += ( (" " * alignment) + objc_super_initializer_arg + suffix + "\n" ) m_snippet += """ if (!self) { return self; } """ if deserialize_class.name == "TSIncomingMessage": m_snippet += """ if (authorUUID != nil) { _authorUUID = authorUUID; } else if (authorPhoneNumber != nil) { _authorPhoneNumber = authorPhoneNumber; } """ for objc_initializer_assign in objc_initializer_assigns: m_snippet += (" " * 4) + objc_initializer_assign + "\n" if deserialize_class.finalize_method_name is not None: m_snippet += """ [self %s]; """ % ( str(deserialize_class.finalize_method_name), ) m_snippet += """ return self; } // clang-format on """ # Skip initializer generation for classes without any properties. if not has_local_properties: h_snippet = "" m_snippet = "" if deserialize_class.filepath.endswith(".m"): m_filepath = deserialize_class.filepath h_filepath = m_filepath[:-2] + ".h" update_objc_snippet(h_filepath, h_snippet) update_objc_snippet(m_filepath, m_snippet) swift_body += """ """ # --- Invoke Initializer initializer_invocation = " return %s(" % str( deserialize_class.name ) swift_body += initializer_invocation initializer_params = [ "grdbId: recordId", ] + initializer_params swift_body += (",\n" + " " * len(initializer_invocation)).join( initializer_params ) swift_body += ")" swift_body += """ """ # TODO: We could generate a comment with the Obj-C (or Swift) model initializer # that this deserialization code expects. swift_body += """ default: owsFailDebug("Unexpected record type: \\(recordType)") throw SDSError.invalidValue() """ swift_body += """ } """ swift_body += """ } """ swift_body += """} """ # TODO: Remove the serialization glue below. if not has_sds_superclass: swift_body += """ // MARK: - SDSModel extension %s: SDSModel { public var serializer: SDSSerializer { // Any subclass can be cast to it's superclass, // so the order of this switch statement matters. // We need to do a "depth first" search by type. switch self {""" % str( clazz.name ) for subclass in reversed(all_descendents_of_class(clazz)): if should_ignore_class(subclass): continue swift_body += """ case let model as %s: assert(type(of: model) == %s.self) return %sSerializer(model: model)""" % ( str(subclass.name), str(subclass.name), str(subclass.name), ) swift_body += """ default: return %sSerializer(model: self) } } public func asRecord() -> SDSRecord { serializer.asRecord() } public var sdsTableName: String { %s.databaseTableName } public static var table: SDSTableMetadata { %sSerializer.table } } """ % ( str(clazz.name), record_name, str(clazz.name), ) if not has_sds_superclass: swift_body += """ // MARK: - DeepCopyable extension %(class_name)s: DeepCopyable { public func deepCopy() throws -> AnyObject { guard let id = self.grdbId?.int64Value else { throw OWSAssertionError("Model missing grdbId.") } // Any subclass can be cast to its superclass, so the order of these if // statements matters. We need to do a "depth first" search by type. """ % { "class_name": str(clazz.name) } classes_to_copy = list(reversed(all_descendents_of_class(clazz))) + [ clazz, ] for class_to_copy in classes_to_copy: if should_ignore_class(class_to_copy): continue if class_to_copy == clazz: swift_body += """ do { let modelToCopy = self assert(type(of: modelToCopy) == %(class_name)s.self) """ % { "class_name": str(class_to_copy.name) } else: swift_body += """ if let modelToCopy = self as? %(class_name)s { assert(type(of: modelToCopy) == %(class_name)s.self) """ % { "class_name": str(class_to_copy.name) } initializer_params = [] base_property_names = set() for property in base_properties: base_property_names.add(property.name) deserialize_properties = properties_and_inherited_properties(class_to_copy) for property in deserialize_properties: value_name = "%s" % property.name did_force_optional = property.name not in ["mentionNotificationMode", "storyViewMode"] did_force_optional = did_force_optional and property.name not in base_property_names did_force_optional = did_force_optional and not property.is_optional did_force_optional = did_force_optional or property.type_info().is_enum for statement in property.deep_copy_record_invocation( value_name, did_force_optional ): swift_body += " %s\n" % (str(statement),) initializer_params.append( "%s: %s" % ( str(property.name), value_name, ) ) swift_body += """ """ # --- Invoke Initializer initializer_invocation = " return %s(" % str(class_to_copy.name) swift_body += initializer_invocation initializer_params = [ "grdbId: id", ] + initializer_params swift_body += (",\n" + " " * len(initializer_invocation)).join( initializer_params ) swift_body += ")" swift_body += """ } """ swift_body += """ } } """ if has_grdb_serializer: swift_body += """ // MARK: - Table Metadata extension %sRecord { // This defines all of the columns used in the table // where this model (and any subclasses) are persisted. internal func asValues() -> [DatabaseValueConvertible?] { return [ """ % str( remove_prefix_from_class_name(clazz.name) ) def write_grdb_column_metadata(metadata): return """ %s, """ % ( str(metadata) ) for property in sds_properties: column_name = property.column_name() if column_name == "recordType" or property.type_info().is_enum: swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name)) elif property.name != "id": swift_body += write_grdb_column_metadata(column_name) for property in record_properties: column_name = property.column_name() if property.type_info().is_enum: swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name)) else: swift_body += write_grdb_column_metadata(column_name) swift_body += """ ] } internal func asArguments() -> StatementArguments { return StatementArguments(asValues()) } } """ if not has_sds_superclass: swift_body += """ // MARK: - Table Metadata extension %sSerializer { // This defines all of the columns used in the table // where this model (and any subclasses) are persisted. """ % str( clazz.name ) # Eventually we need a (persistent?) mechanism for guaranteeing # consistency of column ordering, that is robust to schema # changes, class hierarchy changes, etc. column_property_names = [] def write_column_metadata(property, force_optional=False): column_name = property.swift_identifier() column_property_names.append(column_name) is_optional = property.is_optional or force_optional optional_split = ", isOptional: true" if is_optional else "" is_unique = column_name == str("uniqueId") is_unique_split = ", isUnique: true" if is_unique else "" database_column_type = property.database_column_type() if property.name == "id": database_column_type = ".primaryKey" # TODO: Use skipSelect. return """ static var %sColumn: SDSColumnMetadata { SDSColumnMetadata(columnName: "%s", columnType: %s%s%s) } """ % ( str(column_name), str(column_name), database_column_type, optional_split, is_unique_split, ) for property in sds_properties: swift_body += write_column_metadata(property) if len(record_properties) > 0: swift_body += " // Properties \n" for property in record_properties: swift_body += write_column_metadata( property, force_optional=property.force_optional ) database_table_name = "model_%s" % str(clazz.name) swift_body += """ public static var table: SDSTableMetadata { SDSTableMetadata( tableName: "%s", columns: [ """ % ( database_table_name, ) swift_body += "\n".join( [ " %sColumn," % str(column_property_name) for column_property_name in column_property_names ] ) swift_body += """ ] ) } } """ # ---- Fetch ---- ignore_cache = "" if cache_get_code_for_class(clazz) is not None: ignore_cache = ", ignoreCache: true" swift_body += """ // MARK: - Save/Remove/Update @objc public extension %(class_name)s { func anyInsert(transaction: DBWriteTransaction) { sdsSave(saveMode: .insert, transaction: transaction) } // Avoid this method whenever feasible. // // If the record has previously been saved, this method does an overwriting // update of the corresponding row, otherwise if it's a new record, this // method inserts a new row. // // For performance, when possible, you should explicitly specify whether // you are inserting or updating rather than calling this method. func anyUpsert(transaction: DBWriteTransaction) { let isInserting: Bool if %(class_name)s.anyFetch(uniqueId: uniqueId, transaction: transaction) != nil { isInserting = false } else { isInserting = true } sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction) } // This method is used by "updateWith..." methods. // // This model may be updated from many threads. We don't want to save // our local copy (this instance) since it may be out of date. We also // want to avoid re-saving a model that has been deleted. Therefore, we // use "updateWith..." methods to: // // a) Update a property of this instance. // b) If a copy of this model exists in the database, load an up-to-date copy, // and update and save that copy. // b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save // this local instance. // // After "updateWith...": // // a) Any copy of this model in the database will have been updated. // b) The local property on this instance will always have been updated. // c) Other properties on this instance may be out of date. // // All mutable properties of this class have been made read-only to // prevent accidentally modifying them directly. // // This isn't a perfect arrangement, but in practice this will prevent // data loss and will resolve all known issues. func anyUpdate(transaction: DBWriteTransaction, block: (%(class_name)s) -> Void) { block(self) // If it's not saved, we don't expect to find it in the database, and we // won't save any changes we make back into the database. guard shouldBeSaved else { return } guard let dbCopy = type(of: self).anyFetch(uniqueId: uniqueId, transaction: transaction%(ignore_cache)s) else { return } // Don't apply the block twice to the same instance. // It's at least unnecessary and actually wrong for some blocks. // e.g. `block: { $0 in $0.someField++ }` if dbCopy !== self { block(dbCopy) } dbCopy.sdsSave(saveMode: .update, transaction: transaction) } // This method is an alternative to `anyUpdate(transaction:block:)` methods. // // We should generally use `anyUpdate` to ensure we're not unintentionally // clobbering other columns in the database when another concurrent update // has occurred. // // There are cases when this doesn't make sense, e.g. when we know we've // just loaded the model in the same transaction. In those cases it is // safe and faster to do a "overwriting" update func anyOverwritingUpdate(transaction: DBWriteTransaction) { sdsSave(saveMode: .update, transaction: transaction) } """ % { "class_name": str(clazz.name), "ignore_cache": ignore_cache, } if has_remove_methods: swift_body += """ func anyRemove(transaction: DBWriteTransaction) { sdsRemove(transaction: transaction) } """ swift_body += """} """ # ---- Cursor ---- swift_body += """ // MARK: - %sCursor @objc public class %sCursor: NSObject, SDSCursor { private let transaction: DBReadTransaction private let cursor: RecordCursor<%s>? init(transaction: DBReadTransaction, cursor: RecordCursor<%s>?) { self.transaction = transaction self.cursor = cursor } public func next() throws -> %s? { guard let cursor = cursor else { return nil } guard let record = try cursor.next() else { return nil }""" % ( str(clazz.name), str(clazz.name), record_name, record_name, str(clazz.name), ) cache_code = cache_set_code_for_class(clazz) if cache_code is not None: swift_body += """ let value = try %s.fromRecord(record) %s(value, transaction: transaction) return value""" % ( str(clazz.name), cache_code, ) else: swift_body += """ return try %s.fromRecord(record)""" % ( str(clazz.name), ) swift_body += """ } public func all() throws -> [%s] { var result = [%s]() while true { guard let model = try next() else { break } result.append(model) } return result } } """ % ( str(clazz.name), str(clazz.name), ) # ---- Fetch ---- swift_body += """ // MARK: - Obj-C Fetch @objc public extension %(class_name)s { @nonobjc class func grdbFetchCursor(transaction: DBReadTransaction) -> %(class_name)sCursor { let database = transaction.database do { let cursor = try %(record_name)s.fetchCursor(database) return %(class_name)sCursor(transaction: transaction, cursor: cursor) } catch { DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary( userDefaults: CurrentAppContext().appUserDefaults(), error: error ) owsFailDebug("Read failed: \\(error)") return %(class_name)sCursor(transaction: transaction, cursor: nil) } } """ % { "class_name": str(clazz.name), "record_name": record_name, } swift_body += """ // Fetches a single model by "unique id". class func anyFetch(uniqueId: String, transaction: DBReadTransaction) -> %(class_name)s? { assert(!uniqueId.isEmpty) """ % { "class_name": str(clazz.name), "record_name": record_name, "record_identifier": record_identifier(clazz.name), } cache_code = cache_get_code_for_class(clazz) if cache_code is not None: swift_body += """ return anyFetch(uniqueId: uniqueId, transaction: transaction, ignoreCache: false) } // Fetches a single model by "unique id". class func anyFetch(uniqueId: String, transaction: DBReadTransaction, ignoreCache: Bool) -> %(class_name)s? { assert(!uniqueId.isEmpty) if !ignoreCache, let cachedCopy = %(cache_code)s { return cachedCopy } """ % { "class_name": str(clazz.name), "cache_code": str(cache_code), } swift_body += """ let sql = "SELECT * FROM \\(%(record_name)s.databaseTableName) WHERE \\(%(record_identifier)sColumn: .uniqueId) = ?" return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: transaction) } """ % { "record_name": record_name, "record_identifier": record_identifier(clazz.name), } swift_body += """ // Traverses all records. // Records are not visited in any particular order. class func anyEnumerate( transaction: DBReadTransaction, block: (%s, UnsafeMutablePointer) -> Void ) { anyEnumerate(transaction: transaction, batched: false, block: block) } // Traverses all records. // Records are not visited in any particular order. class func anyEnumerate( transaction: DBReadTransaction, batched: Bool = false, block: (%s, UnsafeMutablePointer) -> Void ) { let batchSize = batched ? Batching.kDefaultBatchSize : 0 anyEnumerate(transaction: transaction, batchSize: batchSize, block: block) } // Traverses all records. // Records are not visited in any particular order. // // If batchSize > 0, the enumeration is performed in autoreleased batches. class func anyEnumerate( transaction: DBReadTransaction, batchSize: UInt, block: (%s, UnsafeMutablePointer) -> Void ) { let cursor = %s.grdbFetchCursor(transaction: transaction) Batching.loop(batchSize: batchSize, loopBlock: { stop in do { guard let value = try cursor.next() else { stop.pointee = true return } block(value, stop) } catch let error { owsFailDebug("Couldn't fetch model: \\(error)") } }) } """ % ( (str(clazz.name),) * 4 ) swift_body += ''' // Traverses all records' unique ids. // Records are not visited in any particular order. class func anyEnumerateUniqueIds( transaction: DBReadTransaction, block: (String, UnsafeMutablePointer) -> Void ) { anyEnumerateUniqueIds(transaction: transaction, batched: false, block: block) } // Traverses all records' unique ids. // Records are not visited in any particular order. class func anyEnumerateUniqueIds( transaction: DBReadTransaction, batched: Bool = false, block: (String, UnsafeMutablePointer) -> Void ) { let batchSize = batched ? Batching.kDefaultBatchSize : 0 anyEnumerateUniqueIds(transaction: transaction, batchSize: batchSize, block: block) } // Traverses all records' unique ids. // Records are not visited in any particular order. // // If batchSize > 0, the enumeration is performed in autoreleased batches. class func anyEnumerateUniqueIds( transaction: DBReadTransaction, batchSize: UInt, block: (String, UnsafeMutablePointer) -> Void ) { grdbEnumerateUniqueIds(transaction: transaction, sql: """ SELECT \\(%sColumn: .uniqueId) FROM \\(%s.databaseTableName) """, batchSize: batchSize, block: block) } ''' % ( record_identifier(clazz.name), record_name, ) swift_body += """ // Does not order the results. class func anyFetchAll(transaction: DBReadTransaction) -> [%s] { var result = [%s]() anyEnumerate(transaction: transaction) { (model, _) in result.append(model) } return result } // Does not order the results. class func anyAllUniqueIds(transaction: DBReadTransaction) -> [String] { var result = [String]() anyEnumerateUniqueIds(transaction: transaction) { (uniqueId, _) in result.append(uniqueId) } return result } """ % ( (str(clazz.name),) * 2 ) # ---- Count ---- swift_body += """ class func anyCount(transaction: DBReadTransaction) -> UInt { return %s.ows_fetchCount(transaction.database) } """ % ( record_name, ) # ---- Exists ---- swift_body += """ class func anyExists( uniqueId: String, transaction: DBReadTransaction ) -> Bool { assert(!uniqueId.isEmpty) let sql = "SELECT EXISTS ( SELECT 1 FROM \\(%s.databaseTableName) WHERE \\(%sColumn: .uniqueId) = ? )" let arguments: StatementArguments = [uniqueId] do { return try Bool.fetchOne(transaction.database, sql: sql, arguments: arguments) ?? false } catch { DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary( userDefaults: CurrentAppContext().appUserDefaults(), error: error ) owsFail("Missing instance.") } } } """ % ( record_name, record_identifier(clazz.name), ) # ---- Fetch ---- swift_body += """ // MARK: - Swift Fetch public extension %(class_name)s { class func grdbFetchCursor(sql: String, arguments: StatementArguments = StatementArguments(), transaction: DBReadTransaction) -> %(class_name)sCursor { do { let sqlRequest = SQLRequest(sql: sql, arguments: arguments, cached: true) let cursor = try %(record_name)s.fetchCursor(transaction.database, sqlRequest) return %(class_name)sCursor(transaction: transaction, cursor: cursor) } catch { DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary( userDefaults: CurrentAppContext().appUserDefaults(), error: error ) owsFailDebug("Read failed: \\(error)") return %(class_name)sCursor(transaction: transaction, cursor: nil) } } """ % { "class_name": str(clazz.name), "record_name": record_name, } string_interpolation_name = remove_prefix_from_class_name(clazz.name) swift_body += """ class func grdbFetchOne(sql: String, arguments: StatementArguments = StatementArguments(), transaction: DBReadTransaction) -> %s? { assert(!sql.isEmpty) do { let sqlRequest = SQLRequest(sql: sql, arguments: arguments, cached: true) guard let record = try %s.fetchOne(transaction.database, sqlRequest) else { return nil } """ % ( str(clazz.name), record_name, ) cache_code = cache_set_code_for_class(clazz) if cache_code is not None: swift_body += """ let value = try %s.fromRecord(record) %s(value, transaction: transaction) return value""" % ( str(clazz.name), cache_code, ) else: swift_body += """ return try %s.fromRecord(record)""" % ( str(clazz.name), ) swift_body += """ } catch { owsFailDebug("error: \\(error)") return nil } } } """ # ---- Typed Convenience Methods ---- if has_sds_superclass: swift_body += """ // MARK: - Typed Convenience Methods @objc public extension %s { // NOTE: This method will fail if the object has unexpected type. class func anyFetch%s( uniqueId: String, transaction: DBReadTransaction ) -> %s? { assert(!uniqueId.isEmpty) guard let object = anyFetch(uniqueId: uniqueId, transaction: transaction) else { return nil } guard let instance = object as? %s else { owsFailDebug("Object has unexpected type: \\(type(of: object))") return nil } return instance } // NOTE: This method will fail if the object has unexpected type. func anyUpdate%s(transaction: DBWriteTransaction, block: (%s) -> Void) { anyUpdate(transaction: transaction) { (object) in guard let instance = object as? %s else { owsFailDebug("Object has unexpected type: \\(type(of: object))") return } block(instance) } } } """ % ( str(clazz.name), str(remove_prefix_from_class_name(clazz.name)), str(clazz.name), str(clazz.name), str(remove_prefix_from_class_name(clazz.name)), str(clazz.name), str(clazz.name), ) # ---- SDSModel ---- table_superclass = clazz.table_superclass() table_class_name = str(table_superclass.name) has_serializable_superclass = table_superclass.name != clazz.name override_keyword = "" swift_body += """ // MARK: - SDSSerializer // The SDSSerializer protocol specifies how to insert and update the // row that corresponds to this model. class %sSerializer: SDSSerializer { private let model: %s public init(model: %s) { self.model = model } """ % ( str(clazz.name), str(clazz.name), str(clazz.name), ) # --- To Record root_class = clazz.table_superclass() root_record_name = remove_prefix_from_class_name(root_class.name) + "Record" record_id_source = "model.grdbId?.int64Value" if root_class.record_id_source() is not None: record_id_source = ( "model.%(source)s > 0 ? Int64(model.%(source)s) : %(default_source)s" % { "source": root_class.record_id_source(), "default_source": record_id_source, } ) swift_body += """ // MARK: - Record func asRecord() -> SDSRecord { let id: Int64? = %(record_id_source)s let recordType: SDSRecordType = .%(record_type)s let uniqueId: String = model.uniqueId """ % { "record_type": get_record_type_enum_name(clazz.name), "record_id_source": record_id_source, } initializer_args = [ "id", "recordType", "uniqueId", ] inherited_property_map = {} for property in properties_and_inherited_properties(clazz): inherited_property_map[property.column_name()] = property def write_record_property(property, force_optional=False): optional_value = "" if property.column_name() in inherited_property_map: inherited_property = inherited_property_map[property.column_name()] did_force_optional = property.force_optional model_accessor = accessor_name_for_property(inherited_property) value_expr = inherited_property.serialize_record_invocation( "model.%s" % (model_accessor,), did_force_optional ) optional_value = " = %s" % (value_expr,) else: optional_value = " = nil" record_field_type = property.record_field_type() is_optional = property.is_optional or force_optional optional_split = "?" if is_optional else "" initializer_args.append(property.column_name()) return """ let %s: %s%s%s """ % ( str(property.column_name()), record_field_type, optional_split, optional_value, ) root_record_properties = root_class.sorted_record_properties() if len(root_record_properties) > 0: swift_body += "\n // Properties \n" for property in root_record_properties: swift_body += write_record_property( property, force_optional=property.force_optional ) initializer_args = [ "%s: %s" % ( arg, arg, ) for arg in initializer_args ] swift_body += """ return %s(delegate: model, %s) } """ % ( root_record_name, ", ".join(initializer_args), ) swift_body += """} """ print(f"Writing {swift_filename}") swift_body = sds_common.clean_up_generated_swift(swift_body) sds_common.write_text_file_if_changed(swift_filepath, swift_body) def process_class_map(class_map): for clazz in class_map.values(): generate_swift_extensions_for_model(clazz) # ---- Record Type Map record_type_map = {} # It's critical that our "record type" values are consistent, even if we add/remove/rename model classes. # Therefore we persist the mapping of known classes in a JSON file that is under source control. def update_record_type_map(record_type_swift_path, record_type_json_path): record_type_map_filepath = record_type_json_path if os.path.exists(record_type_map_filepath): with open(record_type_map_filepath, "rt") as f: json_string = f.read() json_data = json.loads(json_string) record_type_map.update(json_data) max_record_type = 0 for class_name in record_type_map: if class_name.startswith("#"): continue record_type = record_type_map[class_name] max_record_type = max(max_record_type, record_type) for clazz in global_class_map.values(): if clazz.name not in record_type_map: if not clazz.should_generate_extensions(): continue max_record_type = int(max_record_type) + 1 record_type = max_record_type record_type_map[clazz.name] = record_type record_type_map["#comment"] = ( "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`." % (sds_common.pretty_module_path(__file__),) ) json_string = json.dumps(record_type_map, sort_keys=True, indent=4) sds_common.write_text_file_if_changed(record_type_map_filepath, json_string) # TODO: We'll need to import SignalServiceKit for non-SSK classes. swift_body = """// // Copyright 2022 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only // import Foundation import GRDB // NOTE: This file is generated by %s. // Do not manually edit it, instead run `sds_codegen.sh`. @objc public enum SDSRecordType: UInt, CaseIterable { """ % ( sds_common.pretty_module_path(__file__), ) record_type_pairs = [] for key in record_type_map.keys(): if key.startswith("#"): # Ignore comments continue enum_name = get_record_type_enum_name(key) record_type_pairs.append((str(enum_name), record_type_map[key])) record_type_pairs.sort(key=lambda value: value[1]) for enum_name, record_type_id in record_type_pairs: swift_body += """ case %s = %s """ % ( enum_name, str(record_type_id), ) swift_body += """} """ swift_body = sds_common.clean_up_generated_swift(swift_body) sds_common.write_text_file_if_changed(record_type_swift_path, swift_body) def get_record_type(clazz): return record_type_map[clazz.name] def remove_prefix_from_class_name(class_name): name = class_name if name.startswith("TS"): name = name[len("TS") :] elif name.startswith("OWS"): name = name[len("OWS") :] elif name.startswith("SSK"): name = name[len("SSK") :] return name def get_record_type_enum_name(class_name): name = remove_prefix_from_class_name(class_name) if name[0].isnumeric(): name = "_" + name return to_swift_identifier_name(name) def record_identifier(class_name): name = remove_prefix_from_class_name(class_name) return to_swift_identifier_name(name) # ---- Column Ordering column_ordering_map = {} has_loaded_column_ordering_map = False # ---- Parsing enum_type_map = {} def objc_type_for_enum(enum_name): if enum_name not in enum_type_map: print("enum_type_map", enum_type_map) fail("Enum has unknown type:", enum_name) enum_type = enum_type_map[enum_name] return enum_type def swift_type_for_enum(enum_name): objc_type = objc_type_for_enum(enum_name) if objc_type == "NSInteger": return "Int" elif objc_type == "NSUInteger": return "UInt" elif objc_type == "int32_t": return "Int32" elif objc_type == "unsigned long long": return "uint64_t" elif objc_type == "unsigned long long": return "UInt64" elif objc_type == "unsigned long": return "UInt64" elif objc_type == "unsigned int": return "UInt" else: fail("Unknown objc type:", objc_type) def parse_sds_json(file_path): with open(file_path, "rt") as f: json_str = f.read() json_data = json.loads(json_str) classes = json_data["classes"] class_map = {} for class_dict in classes: clazz = ParsedClass(class_dict) class_map[clazz.name] = clazz enums = json_data["enums"] enum_type_map.update(enums) return class_map def try_to_parse_file(file_path): filename = os.path.basename(file_path) _, file_extension = os.path.splitext(filename) if filename.endswith(sds_common.SDS_JSON_FILE_EXTENSION): return parse_sds_json(file_path) else: return {} def find_sds_intermediary_files_in_path(path): class_map = {} if os.path.isfile(path): class_map.update(try_to_parse_file(path)) else: for rootdir, dirnames, filenames in os.walk(path): for filename in filenames: file_path = os.path.abspath(os.path.join(rootdir, filename)) class_map.update(try_to_parse_file(file_path)) return class_map def update_subclass_map(): for clazz in global_class_map.values(): if clazz.super_class_name is not None: subclasses = global_subclass_map.get(clazz.super_class_name, []) subclasses.append(clazz) global_subclass_map[clazz.super_class_name] = subclasses def all_descendents_of_class(clazz): result = [] subclasses = global_subclass_map.get(clazz.name, []) subclasses.sort(key=lambda value: value.name) for subclass in subclasses: result.append(subclass) result.extend(all_descendents_of_class(subclass)) return result def is_swift_class_name(swift_type): return global_class_map.get(swift_type) is not None # ---- Config JSON configuration_json = {} def parse_config_json(config_json_path): with open(config_json_path, "rt") as f: json_str = f.read() json_data = json.loads(json_str) global configuration_json configuration_json = json_data # We often use nullable NSNumber * for optional numerics (bool, int, int64, double, etc.). # There's now way to infer which type we're boxing in NSNumber. # Therefore, we need to specify that in the configuration JSON. def swift_type_for_nsnumber(property): nsnumber_types = configuration_json.get("nsnumber_types") if nsnumber_types is None: print("Suggestion: update: %s" % (str(global_args.config_json_path),)) fail("Configuration JSON is missing mapping for properties of type NSNumber.") key = property.class_name + "." + property.name swift_type = nsnumber_types.get(key) if swift_type is None: print("Suggestion: update: %s" % (str(global_args.config_json_path),)) fail( "Configuration JSON is missing mapping for properties of type NSNumber:", key, ) return swift_type # Some properties shouldn't get serialized. # For now, there's just one: TSGroupModel.groupImage which is a UIImage. # We might end up extending the serialization to handle images. # Or we might store these as Data/NSData/blob. # TODO: def should_ignore_property(property): properties_to_ignore = configuration_json.get("properties_to_ignore") if properties_to_ignore is None: fail( "Configuration JSON is missing list of properties to ignore during serialization." ) key = property.class_name + "." + property.name return key in properties_to_ignore def cache_get_code_for_class(clazz): code_map = configuration_json.get("class_cache_get_code") if code_map is None: fail("Configuration JSON is missing dict of class_cache_get_code.") key = clazz.name return code_map.get(key) def cache_set_code_for_class(clazz): code_map = configuration_json.get("class_cache_set_code") if code_map is None: fail("Configuration JSON is missing dict of class_cache_set_code.") key = clazz.name return code_map.get(key) def should_ignore_class(clazz): class_to_skip_serialization = configuration_json.get("class_to_skip_serialization") if class_to_skip_serialization is None: fail( "Configuration JSON is missing list of classes to ignore during serialization." ) if clazz.name in class_to_skip_serialization: return True if clazz.super_class_name is None: return False if not clazz.super_class_name in global_class_map: return False super_clazz = global_class_map[clazz.super_class_name] return should_ignore_class(super_clazz) def accessor_name_for_property(property): custom_accessors = configuration_json.get("custom_accessors") if custom_accessors is None: fail("Configuration JSON is missing list of custom property accessors.") key = property.class_name + "." + property.name return custom_accessors.get(key, property.name) # include_renamed_columns def custom_column_name_for_property(property): custom_column_names = configuration_json.get("custom_column_names") if custom_column_names is None: fail("Configuration JSON is missing list of custom column names.") key = property.class_name + "." + property.name return custom_column_names.get(key) def aliased_column_name_for_property(property): custom_column_names = configuration_json.get("aliased_column_names") if custom_column_names is None: fail("Configuration JSON is missing dict of aliased_column_names.") key = property.class_name + "." + property.name return custom_column_names.get(key) def was_property_renamed_for_property(property): renamed_column_names = configuration_json.get("renamed_column_names") if renamed_column_names is None: fail("Configuration JSON is missing list of renamed column names.") key = property.class_name + "." + property.name return renamed_column_names.get(key) is not None # ---- Config JSON property_order_json = {} def parse_property_order_json(property_order_json_path): with open(property_order_json_path, "rt") as f: json_str = f.read() json_data = json.loads(json_str) global property_order_json property_order_json = json_data # It's critical that our "property order" is consistent, even if we add columns. # Therefore we persist the "property order" for all known properties in a JSON file that is under source control. def update_property_order_json(property_order_json_path): property_order_json["#comment"] = ( "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`." % (sds_common.pretty_module_path(__file__),) ) json_string = json.dumps(property_order_json, sort_keys=True, indent=4) sds_common.write_text_file_if_changed(property_order_json_path, json_string) def property_order_key(property, record_name): return record_name + "." + property.name def property_order_for_property(property, record_name): key = property_order_key(property, record_name) result = property_order_json.get(key) return result def set_property_order_for_property(property, record_name, value): key = property_order_key(property, record_name) property_order_json[key] = value if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate Swift extensions.") parser.add_argument( "--src-path", required=True, help="used to specify a path to process." ) parser.add_argument( "--search-path", required=True, help="used to specify a path to process." ) parser.add_argument( "--record-type-swift-path", required=True, help="path of the record type enum swift file.", ) parser.add_argument( "--record-type-json-path", required=True, help="path of the record type map json file.", ) parser.add_argument( "--config-json-path", required=True, help="path of the json file with code generation config info.", ) parser.add_argument( "--property-order-json-path", required=True, help="path of the json file with property ordering cache.", ) args = parser.parse_args() global_args = args src_path = os.path.abspath(args.src_path) search_path = os.path.abspath(args.search_path) record_type_swift_path = os.path.abspath(args.record_type_swift_path) record_type_json_path = os.path.abspath(args.record_type_json_path) config_json_path = os.path.abspath(args.config_json_path) property_order_json_path = os.path.abspath(args.property_order_json_path) # We control the code generation process using a JSON config file. parse_config_json(config_json_path) parse_property_order_json(property_order_json_path) # The code generation needs to understand the class hierarchy so that # it can: # # * Define table schemas that include the superset of properties in # the model class hierarchies. # * Generate deserialization methods that handle all subclasses. # * etc. global_class_map.update(find_sds_intermediary_files_in_path(search_path)) update_subclass_map() update_record_type_map(record_type_swift_path, record_type_json_path) process_class_map(find_sds_intermediary_files_in_path(src_path)) # Persist updated property order update_property_order_json(property_order_json_path)