Signal-iOS/Scripts/sds_codegen/sds_generate.py
2022-03-24 10:28:27 -05:00

2441 lines
92 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import subprocess
import argparse
import re
import json
import sds_common
from sds_common import fail
import random
# TODO: We should probably generate a class that knows how to set up
# the database. It would:
#
# * Create all tables (or apply database schema).
# * Register renamed classes.
# [NSKeyedUnarchiver setClass:[OWSUserProfile class] forClassName:[OWSUserProfile collection]];
# [NSKeyedUnarchiver setClass:[OWSDatabaseMigration class] forClassName:[OWSDatabaseMigration collection]];
# We consider any subclass of TSYapDatabaseObject to be a "serializable model".
#
# We treat direct subclasses of TSYapDatabaseObject as "roots" of the model class hierarchy.
# Only root models do deserialization.
# Legacy root model class name. Classes whose superclass is this name (or
# NEW_BASE_MODEL_CLASS_NAME) are treated as roots of the model hierarchy
# by ParsedClass.is_sds_model / has_sds_superclass / table_superclass.
OLD_BASE_MODEL_CLASS_NAME = 'TSYapDatabaseObject'
# Newer root model class name; handled identically to the legacy one above.
NEW_BASE_MODEL_CLASS_NAME = 'BaseModel'
# Marker line that delimits the generated region inside Obj-C files; the
# region between its first and last occurrence is rewritten by
# update_generated_snippet.
CODE_GEN_SNIPPET_MARKER_OBJC = '// --- CODE GENERATION MARKER'
# GRDB seems to encode non-primitive using JSON.
# GRDB chokes when decodes this JSON, due to it being a JSON "fragment".
# Either this is a bug in GRDB or we're using GRDB incorrectly.
# Until we resolve this issue, we need to encode/decode
# non-primitives ourselves.
# While USE_CODABLE_FOR_PRIMITIVES is False, ParsedProperty.is_objc_type_codable
# always returns False.
USE_CODABLE_FOR_PRIMITIVES = False
USE_CODABLE_FOR_NONPRIMITIVES = False
def update_generated_snippet(file_path, marker, snippet):
    """Replace the region between the first and last *marker* in *file_path*.

    The text before the first marker is stripped of surrounding whitespace,
    the text after the last marker loses its leading whitespace, and the
    pieces (prefix, marker, snippet, marker, suffix) are joined with blank
    lines. The result is written via sds_common.write_text_file_if_changed.

    Calls ``fail`` if the file does not exist or if two distinct markers
    cannot be found.
    """
    if not os.path.exists(file_path):
        fail('Missing file:', file_path)
    with open(file_path, 'rt') as handle:
        original_text = handle.read()
    first = original_text.find(marker)
    last = original_text.rfind(marker)
    # find/rfind land on the same index when only one marker exists, so we
    # require the first occurrence to lie strictly before the last one.
    if not (0 <= first < last):
        fail('Could not find markers:', file_path)
    head = original_text[:first].strip()
    tail = original_text[last + len(marker):].lstrip()
    new_text = '\n\n'.join((head, marker, snippet, marker, tail))
    sds_common.write_text_file_if_changed(file_path, new_text)
def update_objc_snippet(file_path, snippet):
    """Write an Obj-C *snippet* into the generated region of *file_path*.

    The snippet is normalized with sds_common.clean_up_generated_objc and
    stripped; if nothing remains the file is left untouched. Otherwise a
    "do not edit" header is prepended and the result is spliced between the
    CODE_GEN_SNIPPET_MARKER_OBJC markers.
    """
    cleaned = sds_common.clean_up_generated_objc(snippet).strip()
    if not cleaned:
        return
    header = '// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`.' % (sds_common.pretty_module_path(__file__),)
    update_generated_snippet(file_path, CODE_GEN_SNIPPET_MARKER_OBJC, header + '\n\n' + cleaned)
# ----
# Class name -> ParsedClass; consulted when walking superclass chains.
# NOTE(review): presumably populated elsewhere in this script — confirm.
global_class_map = {}
# Not read by the code visible here; presumably class name -> subclasses.
# TODO confirm against the code that fills it.
global_subclass_map = {}
# Presumably the parsed command-line arguments (argparse is imported);
# TODO confirm where it is assigned.
global_args = None
# ----
def to_swift_identifier_name(identifier_name):
    """Return *identifier_name* with only its first character lower-cased.

    e.g. 'UniqueId' -> 'uniqueId'. An empty string is returned unchanged.
    """
    # Slicing ([:1]) instead of indexing ([0]) keeps the empty string safe.
    return identifier_name[:1].lower() + identifier_name[1:]
class ParsedClass:
    """A model class as described by the parser's JSON output.

    Other classes in the hierarchy are resolved by name through
    ``global_class_map``.
    """

    def __init__(self, json_dict):
        # Keys read: name, super_class_name, filepath, finalize_method_name,
        # properties (a list of property dicts).
        self.name = json_dict.get('name')
        self.super_class_name = json_dict.get('super_class_name')
        self.filepath = sds_common.sds_from_relative_path(json_dict.get('filepath'))
        self.finalize_method_name = json_dict.get('finalize_method_name')
        # Property name -> ParsedProperty, excluding ignored properties.
        self.property_map = {}
        # NOTE(review): assumes 'properties' is always present; if it is
        # missing, .get() returns None and this loop raises TypeError.
        for property_dict in json_dict.get('properties'):
            property = ParsedProperty(property_dict)
            # Attribute each property to the class that declares it.
            property.class_name = self.name
            # TODO: We should handle all properties?
            if property.should_ignore_property():
                print('Ignoring property:', property.name)
                continue
            self.property_map[property.name] = property

    def properties(self):
        """Return this class's own (non-ignored) properties, sorted by name."""
        result = []
        for name in sorted(self.property_map.keys()):
            result.append(self.property_map[name])
        return result

    def database_subclass_properties(self):
        """Return properties declared by descendants but not by this class.

        Walks every non-ignored descendant. A re-declared property must have
        the same Swift type, otherwise ``fail`` is called. If two declarations
        differ only in optionality, ``fail`` is called when this class itself
        declares the property; otherwise the optional declaration wins.
        The result is sorted by property name.
        """
        # More than one subclass of a SDS model may declare properties
        # with the same name. This is fine, so long as they have
        # the same type.
        all_property_map = {}
        subclass_property_map = {}
        root_property_names = set()
        for property in self.properties():
            all_property_map[property.name] = property
            root_property_names.add(property.name)
        for subclass in all_descendents_of_class(self):
            if should_ignore_class(subclass):
                continue
            for property in subclass.properties():
                duplicate_property = all_property_map.get(property.name)
                if duplicate_property is not None:
                    if property.swift_type_safe() != duplicate_property.swift_type_safe():
                        print('property:', property.class_name, property.name, property.swift_type_safe(), property.is_optional)
                        print('duplicate_property:', duplicate_property.class_name, duplicate_property.name, duplicate_property.swift_type_safe(), duplicate_property.is_optional)
                        fail("Duplicate property doesn't match:", property.name)
                    elif property.is_optional != duplicate_property.is_optional:
                        if property.name in root_property_names:
                            print('property:', property.class_name, property.name, property.swift_type_safe(), property.is_optional)
                            print('duplicate_property:', duplicate_property.class_name, duplicate_property.name, duplicate_property.swift_type_safe(), duplicate_property.is_optional)
                            fail("Duplicate property doesn't match:", property.name)
                        # If one subclass property is optional and the other isn't, we should
                        # treat both as optional for the purposes of the database schema.
                        if not property.is_optional:
                            # Keep the already-recorded (optional) declaration.
                            continue
                    else:
                        # Identical re-declaration; nothing new to record.
                        continue
                # First sighting, or an optional declaration replacing a
                # non-optional one.
                all_property_map[property.name] = property
                subclass_property_map[property.name] = property
        result = []
        for name in sorted(subclass_property_map.keys()):
            result.append(subclass_property_map[name])
        return result

    def record_id_source(self):
        """Return 'sortId' if this class declares a sortId property, else None."""
        for property in self.properties():
            if property.name == 'sortId':
                return property.name
        return None

    def is_sds_model(self):
        """Return True if the known superclass chain reaches a base model class.

        Note that the superclass name must itself be in global_class_map
        (checked before the base-name test), so the base classes must be
        registered there too.
        """
        if self.super_class_name is None:
            return False
        if not self.super_class_name in global_class_map:
            return False
        if self.super_class_name in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME, ):
            return True
        super_class = global_class_map[self.super_class_name]
        return super_class.is_sds_model()

    def has_sds_superclass(self):
        """Truthy when the superclass is a known class other than a base model.

        Returns the value of the ``and`` chain (may be None or '' rather than
        False when there is no superclass name).
        """
        return (self.super_class_name and
                self.super_class_name in global_class_map
                and self.super_class_name != OLD_BASE_MODEL_CLASS_NAME
                and self.super_class_name != NEW_BASE_MODEL_CLASS_NAME)

    def table_superclass(self):
        """Return the top-most ancestor below the base model (or unknown) class.

        Presumably this is the class whose table stores this model and its
        subclasses — TODO confirm against callers.
        """
        if self.super_class_name is None:
            return self
        if not self.super_class_name in global_class_map:
            return self
        if self.super_class_name == OLD_BASE_MODEL_CLASS_NAME:
            return self
        if self.super_class_name == NEW_BASE_MODEL_CLASS_NAME:
            return self
        super_class = global_class_map[self.super_class_name]
        return super_class.table_superclass()

    def all_superclass_names(self):
        """Return [self.name, parent name, ...] for all ancestors in global_class_map."""
        result = [self.name]
        if self.super_class_name is not None:
            if self.super_class_name in global_class_map:
                super_class = global_class_map[self.super_class_name]
                result += super_class.all_superclass_names()
        return result

    def has_any_superclass_with_name(self, name):
        """Return True if *name* is this class's name or a known ancestor's."""
        return name in self.all_superclass_names()

    def should_generate_extensions(self):
        """Return True if a +SDS.swift extension should be generated for this class.

        Skips the base model classes, ignored classes, non-SDS models and
        database-migration classes (and their direct subclasses), printing
        the reason for each skip.
        """
        if self.name in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME, ):
            print('Ignoring class (1):', self.name)
            return False
        if should_ignore_class(self):
            print('Ignoring class (2):', self.name)
            return False
        if not self.is_sds_model():
            # Only write serialization extensions for SDS models.
            print('Ignoring class (3):', self.name)
            return False
        # The migration should not be persisted in the data store.
        if self.name in ('OWSDatabaseMigration', 'YDBDatabaseMigration', 'OWSResaveCollectionDBMigration', ):
            print('Ignoring class (4):', self.name)
            return False
        if self.super_class_name in ('OWSDatabaseMigration', 'YDBDatabaseMigration', 'OWSResaveCollectionDBMigration', ):
            print('Ignoring class (5):', self.name)
            return False
        return True

    def record_name(self):
        """Return the generated record struct name: prefix-less class name + 'Record'."""
        return remove_prefix_from_class_name(self.name) + 'Record'

    def sorted_record_properties(self):
        """Return all record column properties in their stable, persisted order.

        Side effects: sets ``force_optional`` and ``property_order`` on each
        returned ParsedProperty, and calls set_property_order_for_property
        for every property that had no known order yet (in name order).
        """
        record_name = self.record_name()
        # If a property has a custom column source, we don't redundantly create a column for that column
        base_properties = [property for property in self.properties() if not property.has_custom_column_source()]
        # If a property has a custom column source, we don't redundantly create a column for that column
        subclass_properties = [property for property in self.database_subclass_properties() if not property.has_custom_column_source()]
        # We need to maintain a stable ordering of record properties
        # across migrations, e.g. adding new columns to the tables.
        #
        # First, we build a list of "model" properties. This is the
        # the superset of properties in the model base class and all
        # of its subclasses.
        #
        # NOTE: We punch two values onto these properties:
        # force_optional and property_order.
        record_properties = []
        for property in base_properties:
            property.force_optional=False
            record_properties.append(property)
        for property in subclass_properties:
            # We must "force" subclass properties to be optional
            # since they don't apply to the base model and other
            # subclasses.
            property.force_optional=True
            record_properties.append(property)
        for property in record_properties:
            # Try to load the known "order" for each property.
            #
            # "Orders" are indices used to ensure a stable ordering.
            # We find the "orders" of all properties that already have
            # one.
            #
            # This will initially be nil for new properties
            # which have not yet been assigned an order.
            property.property_order = property_order_for_property(property, record_name)
        all_property_orders = [property.property_order for property in record_properties if property.property_order]
        # We determine the "next" order we would assign to any
        # new property without an order.
        next_property_order = 1 + (max(all_property_orders) if len(all_property_orders) > 0 else 0)
        # Pre-sort model properties by name, so that if we add more
        # than one at a time they are nicely (and stable-y) sorted
        # in an attractive way.
        record_properties.sort(key=lambda value: value.name)
        # Now iterate over all model properties and assign an order
        # to any new properties without one.
        for property in record_properties:
            if property.property_order is None:
                property.property_order = next_property_order
                # We "set" the order in the mapping which is persisted
                # as JSON to ensure continuity.
                set_property_order_for_property(property, record_name, next_property_order)
                next_property_order = next_property_order + 1
        # Now sort the model properties, applying the ordering.
        record_properties.sort(key=lambda value: value.property_order)
        return record_properties
class TypeInfo:
    """How one model property is typed in Swift, Obj-C and the database.

    Holds a property's Swift and Obj-C type names plus flags that select the
    database column type and the Swift source snippets generated to move the
    value between the model object and its GRDB record.
    """

    # Swift numeric type -> NSNumber accessor used to unbox a boxed value
    # when serializing an `NSNumber *` property.
    _NSNUMBER_CONVERSION_MAP = {
        'Int8': 'int8Value',
        'UInt8': 'uint8Value',
        'Int16': 'int16Value',
        'UInt16': 'uint16Value',
        'Int32': 'int32Value',
        'UInt32': 'uint32Value',
        'Int64': 'int64Value',
        'UInt64': 'uint64Value',
        'Float': 'floatValue',
        'Double': 'doubleValue',
        'Bool': 'boolValue',
        'Int': 'intValue',
        'UInt': 'uintValue',
    }

    def __init__(self, swift_type, objc_type, should_use_blob=False, is_codable=False, is_enum=False, field_override_column_type=None, field_override_record_swift_type=None):
        self._swift_type = swift_type
        self._objc_type = objc_type
        # Store the value as an archived blob (sub-models, collections, ...).
        self.should_use_blob = should_use_blob
        self.is_codable = is_codable
        self.is_enum = is_enum
        # Per-field overrides; when not None they take precedence.
        self.field_override_column_type = field_override_column_type
        self.field_override_record_swift_type = field_override_record_swift_type

    def swift_type(self):
        """Return the Swift type name as a string."""
        return str(self._swift_type)

    def objc_type(self):
        """Return the Obj-C type name as a string."""
        return str(self._objc_type)

    # This defines the mapping of Swift types to database column types.
    # We'll be iterating on this mapping.
    # Note that we currently store all sub-models and collections (e.g. [String]) as a blob.
    def database_column_type(self, value_name):
        """Return the column type expression (e.g. '.int64') for this type.

        ``value_name`` is accepted for interface compatibility but unused.
        Calls ``fail`` for a Swift ``Date`` type or an unrecognized type.
        """
        if self.field_override_column_type is not None:
            return self.field_override_column_type
        elif self.should_use_blob or self.is_codable:
            return '.blob'
        elif self.is_enum:
            return '.int'
        elif self._swift_type == 'String':
            return '.unicodeString'
        elif self._objc_type == 'NSDate *':
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            return '.double'
        elif self._swift_type == 'Date':
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            fail('We should not use `Date` as a "swift type" since all NSDates are serialized as doubles.', self._swift_type)
        elif self._swift_type == 'Data':
            return '.blob'
        elif self._swift_type == 'Bool':
            # Checked before is_numeric(), which would map Bool to '.int64'.
            return '.int'
        elif self._swift_type in ('Double', 'Float'):
            return '.double'
        elif self.is_numeric():
            return '.int64'
        else:
            fail('Unknown type(1):', self._swift_type)

    def is_numeric(self):
        """Return True if the Swift type is one of the supported numeric types."""
        # TODO: We need to revisit how we serialize numeric types.
        return self._swift_type in (
            'Bool',
            'UInt64',
            'UInt',
            'Int64',
            'Int',
            'Int32',
            'UInt32',
            'Double',
            'Float'
        )

    def should_cast_to_swift(self):
        """Return True for numeric types other than Bool, Int64 and UInt64."""
        if self._swift_type in ('Bool', 'Int64', 'UInt64',):
            return False
        return self.is_numeric()

    def deserialize_record_invocation(self, property, value_name, is_optional, did_force_optional):
        """Return Swift statements that read this property out of ``record``.

        The returned list of source lines declares a local named
        ``value_name`` holding the deserialized value. ``did_force_optional``
        means the column is optional in the record even though the model
        property is not (e.g. subclass-only columns), so the value has to be
        unwrapped.
        """
        custom_column_name = custom_column_name_for_property(property)
        if custom_column_name is not None:
            value_expr = 'record.%s' % (custom_column_name,)
        else:
            value_expr = 'record.%s' % (value_name,)

        deserialization_optional = None
        deserialization_not_optional = None
        deserialization_conversion = ''
        if self._swift_type == 'String':
            deserialization_not_optional = 'required'
        elif self._objc_type == 'NSDate *':
            pass
        elif self._swift_type == 'Date':
            fail('Unknown type(0):', self._swift_type)
        elif self.is_codable:
            deserialization_not_optional = 'required'
        elif self._swift_type == 'Data':
            deserialization_optional = 'optionalData'
            deserialization_not_optional = 'required'
        elif self.is_numeric():
            deserialization_optional = 'optionalNumericAsNSNumber'
            deserialization_not_optional = 'required'
            deserialization_conversion = ', conversion: { NSNumber(value: $0) }'

        if is_optional:
            if deserialization_optional is not None:
                value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % (deserialization_optional, value_expr, value_name, deserialization_conversion)
        elif did_force_optional:
            if deserialization_not_optional is not None:
                value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % (deserialization_not_optional, value_expr, value_name)
        else:
            # Do nothing; we don't need to unpack this non-optional.
            pass

        initializer_param_type = self.swift_type()
        if is_optional:
            initializer_param_type = initializer_param_type + '?'

        # Special case this oddball type.
        if property.has_custom_column_source():
            value_expr = property.column_source()
            value_expr = 'record.%s' % (value_expr,)
            # Special-case the unpacking of the auto-incremented
            # primary key.
            if value_expr == 'record.id':
                value_expr = 'recordId'
            value_statement = 'let %s: %s = %s(%s)' % (value_name, initializer_param_type, initializer_param_type, value_expr,)
        elif value_name == 'conversationColorName':
            value_statement = 'let %s: %s = ConversationColorName(rawValue: %s)' % (value_name, "ConversationColorName", value_expr,)
        elif value_name == 'mentionNotificationMode':
            value_statement = 'let %s: %s = TSThreadMentionNotificationMode(rawValue: %s) ?? .default' % (value_name, "TSThreadMentionNotificationMode", value_expr,)
        elif self.is_codable:
            value_statement = 'let %s: %s = %s' % (value_name, initializer_param_type, value_expr,)
        elif self.should_use_blob:
            blob_name = '%sSerialized' % (str(value_name),)
            if is_optional or did_force_optional:
                serialized_statement = 'let %s: Data? = %s' % (blob_name, value_expr,)
            else:
                serialized_statement = 'let %s: Data = %s' % (blob_name, value_expr,)
            if is_optional:
                value_statement = 'let %s: %s? = try SDSDeserialization.optionalUnarchive(%s, name: "%s")' % (value_name, self._swift_type, blob_name, value_name,)
            else:
                value_statement = 'let %s: %s = try SDSDeserialization.unarchive(%s, name: "%s")' % (value_name, self._swift_type, blob_name, value_name,)
            return [serialized_statement, value_statement,]
        elif self.is_enum and did_force_optional and not is_optional:
            return [
                'guard let %s: %s = %s else {' % (value_name, initializer_param_type, value_expr,),
                '   throw SDSError.missingRequiredField',
                '}',
            ]
        elif is_optional and self._objc_type == 'NSNumber *':
            return [
                'let %s: %s = %s' % (value_name, 'NSNumber?', value_expr,),
            ]
        elif self._objc_type == 'NSDate *':
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            interval_name = '%sInterval' % (str(value_name),)
            if did_force_optional:
                serialized_statements = [
                    'guard let %s: Double = %s else {' % (interval_name, value_expr,),
                    '   throw SDSError.missingRequiredField',
                    '}',
                ]
            elif is_optional:
                serialized_statements = ['let %s: Double? = %s' % (interval_name, value_expr,),]
            else:
                serialized_statements = ['let %s: Double = %s' % (interval_name, value_expr,),]
            if is_optional:
                value_statement = 'let %s: Date? = SDSDeserialization.optionalDoubleAsDate(%s, name: "%s")' % (value_name, interval_name, value_name,)
            else:
                value_statement = 'let %s: Date = SDSDeserialization.requiredDoubleAsDate(%s, name: "%s")' % (value_name, interval_name, value_name,)
            return serialized_statements + [value_statement,]
        else:
            value_statement = 'let %s: %s = %s' % (value_name, initializer_param_type, value_expr,)
        return [value_statement,]

    def serialize_record_invocation(self, property, value_name, is_optional, did_force_optional):
        """Return a Swift expression converting model value ``value_name`` for the record.

        A per-field ``serialize_record_invocation`` override (a %-format string
        with one placeholder) wins. Otherwise blobs are archived, NSDates
        become doubles, and NSNumbers are unboxed with the accessor matching
        the Swift type; anything else is passed through unchanged.
        Calls ``fail`` for an NSNumber whose Swift type has no known accessor.
        """
        value_expr = value_name
        override_invocation = property.field_override_serialize_record_invocation()
        if override_invocation is not None:
            return override_invocation % (value_expr,)
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            if is_optional or did_force_optional:
                return 'optionalArchive(%s)' % (value_expr,)
            else:
                return 'requiredArchive(%s)' % (value_expr,)
        elif self._objc_type == 'NSDate *':
            if is_optional or did_force_optional:
                return 'archiveOptionalDate(%s)' % (value_expr,)
            else:
                return 'archiveDate(%s)' % (value_expr,)
        elif self._objc_type == 'NSNumber *':
            # .get() so an unmapped type reaches the diagnostic below instead
            # of raising a bare KeyError.
            conversion_method = self._NSNUMBER_CONVERSION_MAP.get(self.swift_type())
            if conversion_method is None:
                fail('Could not convert:', self.swift_type())
            serialization_conversion = '{ $0.%s }' % (conversion_method,)
            if is_optional or did_force_optional:
                return 'archiveOptionalNSNumber(%s, conversion: %s)' % (value_expr, serialization_conversion,)
            else:
                return 'archiveNSNumber(%s, conversion: %s)' % (value_expr, serialization_conversion,)
        return value_expr

    def record_field_type(self, value_name):
        """Return the Swift type of this property's field in the generated record.

        ``value_name`` is accepted for interface compatibility but unused.
        """
        # Special case this oddball type.
        if self.field_override_record_swift_type is not None:
            return self.field_override_record_swift_type
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            return 'Data'
        return self.swift_type()
class ParsedProperty:
    """A single property of a parsed model class, as read from the parser's JSON.

    Maps the property's Obj-C type to a Swift type / TypeInfo, honouring
    per-field overrides from configuration_json['manually_typed_fields']
    (keyed by '<ClassName>.<propertyName>').
    """

    # Obj-C types with a fixed Swift equivalent. 'NSNumber *' is handled
    # separately because its Swift type depends on the property.
    _OBJC_PRIMITIVE_TO_SWIFT = {
        'NSString *': 'String',
        # Persist dates as NSTimeInterval timeIntervalSince1970.
        'NSDate *': 'Double',
        'NSData *': 'Data',
        'BOOL': 'Bool',
        'NSInteger': 'Int',
        'NSUInteger': 'UInt',
        'int32_t': 'Int32',
        'uint32_t': 'UInt32',
        'int64_t': 'Int64',
        'long long': 'Int64',
        'unsigned long long': 'UInt64',
        'uint64_t': 'UInt64',
        'unsigned long': 'UInt64',
        'unsigned int': 'UInt32',
        'double': 'Double',
        'float': 'Float',
        'CGFloat': 'Double',
    }

    def __init__(self, json_dict):
        self.name = json_dict.get('name')
        self.is_optional = json_dict.get('is_optional')
        self.objc_type = json_dict.get('objc_type')
        self.class_name = json_dict.get('class_name')
        # Explicit Swift type; not assigned elsewhere in this class.
        self.swift_type = None

    def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True):
        """Return the Swift type for a primitive/simple Obj-C type, or None.

        'NSNumber *' becomes the property's concrete numeric Swift type when
        ``unpack_nsnumber`` is true, otherwise plain 'NSNumber'.
        Calls ``fail`` if ``objc_type`` is None.
        """
        if objc_type is None:
            fail('Missing type')
        elif objc_type == 'NSNumber *':
            if unpack_nsnumber:
                return swift_type_for_nsnumber(self)
            return 'NSNumber'
        return self._OBJC_PRIMITIVE_TO_SWIFT.get(objc_type)

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True):
        """Return the Swift spelling of an Obj-C object type, or None if not a pointer type.

        Arrays and dictionaries are converted element-wise; ordered sets lose
        their element type. Calls ``fail`` for other generic/complex types.
        """
        if objc_type == 'id':
            return 'AnyObject'
        elif not objc_type.endswith(' *'):
            return None

        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type, unpack_nsnumber=unpack_nsnumber)
        if swift_primitive is not None:
            return swift_primitive

        array_match = re.search(r'^NS(Mutable)?Array<(.+)> \*$', objc_type)
        if array_match is not None:
            split = array_match.group(2)
            return '[' + self.convert_objc_class_to_swift(split, unpack_nsnumber=False) + ']'

        dict_match = re.search(r'^NS(Mutable)?Dictionary<(.+),(.+)> \*$', objc_type)
        if dict_match is not None:
            split1 = dict_match.group(2).strip()
            split2 = dict_match.group(3).strip()
            return '[' + self.convert_objc_class_to_swift(split1, unpack_nsnumber=False) + ': ' + self.convert_objc_class_to_swift(split2, unpack_nsnumber=False) + ']'

        ordered_set_match = re.search(r'^NSOrderedSet<(.+)> \*$', objc_type)
        if ordered_set_match is not None:
            # swift has no primitive for ordered set, so we lose the element type
            return 'NSOrderedSet'

        swift_type = objc_type[:-len(' *')]

        if '<' in swift_type or '{' in swift_type or '*' in swift_type:
            fail('Unexpected type:', objc_type)
        return swift_type

    def try_to_convert_objc_type_to_type_info(self):
        """Build a TypeInfo from this property's Obj-C type (and any overrides).

        Calls ``fail`` when the type is missing or cannot be mapped.
        """
        objc_type = self.objc_type

        if objc_type is None:
            fail('Missing type')
        elif self.field_override_swift_type():
            return TypeInfo(self.field_override_swift_type(), objc_type, should_use_blob=self.field_override_should_use_blob(), is_enum=self.field_override_is_enum(), field_override_column_type=self.field_override_column_type(), field_override_record_swift_type=self.field_override_record_swift_type())
        elif objc_type in enum_type_map:
            enum_type = objc_type
            return TypeInfo(enum_type, objc_type, is_enum=True)
        elif objc_type.startswith('enum '):
            enum_type = objc_type[len('enum '):]
            return TypeInfo(enum_type, objc_type, is_enum=True)

        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type)
        if swift_primitive is not None:
            return TypeInfo(swift_primitive, objc_type)

        if objc_type in ('struct CGSize', 'struct CGRect', 'struct CGPoint', ):
            objc_type = objc_type[len('struct '):]
            swift_type = objc_type
            return TypeInfo(swift_type, objc_type, should_use_blob=True, is_codable=USE_CODABLE_FOR_PRIMITIVES)

        swift_type = self.convert_objc_class_to_swift(self.objc_type)
        if swift_type is not None:
            # NOTE(review): both branches currently build the same TypeInfo
            # (is_codable=False); kept as-is pending a decision on codables.
            if self.is_objc_type_codable(objc_type):
                return TypeInfo(swift_type, objc_type, should_use_blob=True, is_codable=False)
            return TypeInfo(swift_type, objc_type, should_use_blob=True, is_codable=False)

        fail('Unknown type(3):', self.class_name, self.objc_type, self.name)

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def is_objc_type_codable(self, objc_type):
        """Return True if *objc_type* may be stored via Codable.

        Always False while USE_CODABLE_FOR_PRIMITIVES is False; collections
        additionally require USE_CODABLE_FOR_NONPRIMITIVES.
        """
        if not USE_CODABLE_FOR_PRIMITIVES:
            return False
        if objc_type in ('NSString *',):
            return True
        elif objc_type in ('struct CGSize', 'struct CGRect', 'struct CGPoint', ):
            return True
        elif self.field_override_is_objc_codable() is not None:
            return self.field_override_is_objc_codable()
        elif objc_type in enum_type_map:
            return True
        elif objc_type.startswith('enum '):
            return True

        if not USE_CODABLE_FOR_NONPRIMITIVES:
            return False

        array_match = re.search(r'^NS(Mutable)?Array<(.+)> \*$', objc_type)
        if array_match is not None:
            split = array_match.group(2)
            return self.is_objc_type_codable(split)

        dict_match = re.search(r'^NS(Mutable)?Dictionary<(.+),(.+)> \*$', objc_type)
        if dict_match is not None:
            split1 = dict_match.group(2).strip()
            split2 = dict_match.group(3).strip()
            return self.is_objc_type_codable(split1) and self.is_objc_type_codable(split2)

        return False

    def field_override_swift_type(self):
        return self._field_override("swift_type")

    def field_override_is_objc_codable(self):
        return self._field_override("is_objc_codable")

    def field_override_is_enum(self):
        return self._field_override("is_enum")

    def field_override_column_type(self):
        return self._field_override("column_type")

    def field_override_record_swift_type(self):
        return self._field_override("record_swift_type")

    def field_override_serialize_record_invocation(self):
        return self._field_override("serialize_record_invocation")

    def field_override_should_use_blob(self):
        return self._field_override("should_use_blob")

    def field_override_objc_initializer_type(self):
        return self._field_override("objc_initializer_type")

    def _field_override(self, override_field):
        """Return the configured override *override_field* for this property, or None.

        Overrides live in configuration_json['manually_typed_fields'] under
        '<class_name>.<name>'. An entry need not list every field; a missing
        field means "no override". Calls ``fail`` if the section is absent.
        """
        manually_typed_fields = configuration_json.get('manually_typed_fields')
        if manually_typed_fields is None:
            fail('Configuration JSON is missing manually_typed_fields')
        key = self.class_name + '.' + self.name
        overrides = manually_typed_fields.get(key)
        if overrides is None:
            return None
        return overrides.get(override_field)

    def type_info(self):
        """Return the TypeInfo describing this property."""
        if self.swift_type is not None:
            should_use_blob = (self.swift_type.startswith(('[', '{')) or is_swift_class_name(self.swift_type))
            # Pass the override's value (not the bound method) and this
            # property's own Obj-C type.
            return TypeInfo(self.swift_type, self.objc_type, should_use_blob=should_use_blob, is_codable=USE_CODABLE_FOR_PRIMITIVES, field_override_column_type=self.field_override_column_type())
        return self.try_to_convert_objc_type_to_type_info()

    def swift_type_safe(self):
        return self.type_info().swift_type()

    def objc_type_safe(self):
        """Return the Obj-C type used in initializers, without any 'enum ' prefix."""
        if self.field_override_objc_initializer_type() is not None:
            return self.field_override_objc_initializer_type()
        result = self.type_info().objc_type()
        if result.startswith('enum '):
            result = result[len('enum '):]
        return result

    def database_column_type(self):
        return self.type_info().database_column_type(self.name)

    def should_ignore_property(self):
        # Delegates to the module-level should_ignore_property() function.
        return should_ignore_property(self)

    def column_source(self):
        """Return the custom column source if configured, else the property name."""
        custom_name = custom_property_column_source(self)
        if custom_name is not None:
            return custom_name
        else:
            return self.name

    def has_custom_column_source(self):
        return custom_property_column_source(self) is not None

    def deserialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().deserialize_record_invocation(self, value_name, self.is_optional, did_force_optional)

    def deep_copy_record_invocation(self, value_name, did_force_optional):
        """Return Swift statements that copy this property from ``modelToCopy``.

        Simple value types and numerics/enums are copied directly; anything
        else goes through DeepCopies.deepCopy().
        """
        swift_type = self.swift_type_safe()
        objc_type = self.objc_type_safe()
        is_optional = self.is_optional
        model_accessor = accessor_name_for_property(self)

        initializer_param_type = swift_type
        if is_optional:
            initializer_param_type = initializer_param_type + '?'

        simple_type_map = {
            'NSString *': 'String',
            'NSNumber *': 'NSNumber',
            'NSDate *': 'Date',
            'NSData *': 'Data',
            'CGSize': 'CGSize',
            'CGRect': 'CGRect',
            'CGPoint': 'CGPoint',
        }
        if objc_type in simple_type_map:
            initializer_param_type = simple_type_map[objc_type]
            if is_optional:
                initializer_param_type += '?'
            return ['let %s: %s = modelToCopy.%s' % (value_name, initializer_param_type, model_accessor,),]

        # Numerics and enums are value types and can be copied directly.
        if self.type_info().is_numeric() or self.is_enum():
            return ['let %s: %s = modelToCopy.%s' % (value_name, initializer_param_type, model_accessor,),]

        initializer_param_type = initializer_param_type.replace('AnyObject', 'Any')

        if is_optional:
            return [
                '// NOTE: If this generates build errors, you made need to',
                '// modify DeepCopy.swift to support this type.',
                '//',
                '// That might mean:',
                '//',
                '// * Implement DeepCopyable for this type (e.g. a model).',
                '// * Modify DeepCopies.deepCopy() to support this type (e.g. a collection).',
                'let %s: %s' % (value_name, initializer_param_type,),
                'if let %sForCopy = modelToCopy.%s {' % (value_name, model_accessor,),
                '   %s = try DeepCopies.deepCopy(%sForCopy)' % (value_name, value_name,),
                '} else {',
                '   %s = nil' % (value_name,),
                '}',
            ]
        return [
            '// NOTE: If this generates build errors, you made need to',
            '// implement DeepCopyable for this type in DeepCopy.swift.',
            'let %s: %s = try DeepCopies.deepCopy(modelToCopy.%s)' % (value_name, initializer_param_type, model_accessor,),
        ]

    def possible_class_type_for_property(self):
        """Return the ParsedClass named by this property's Swift or Obj-C type, if any."""
        swift_type = self.swift_type_safe()
        if swift_type in global_class_map:
            return global_class_map[swift_type]
        objc_type = self.objc_type_safe()
        if objc_type.endswith(' *'):
            objc_type = objc_type[:-2]
        if objc_type in global_class_map:
            return global_class_map[objc_type]
        return None

    def serialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().serialize_record_invocation(self, value_name, self.is_optional, did_force_optional)

    def record_field_type(self):
        return self.type_info().record_field_type(self.name)

    def is_enum(self):
        return self.type_info().is_enum

    def swift_identifier(self):
        return to_swift_identifier_name(self.name)

    def column_name(self):
        """Return the custom column name if configured, else the Swift identifier."""
        custom_column_name = custom_column_name_for_property(self)
        if custom_column_name is not None:
            return custom_column_name
        else:
            return self.swift_identifier()
def ows_getoutput(cmd):
    """Run *cmd* to completion and capture both output streams.

    Returns a ``(returncode, stdout, stderr)`` tuple where both streams are
    bytes. A non-zero exit status is returned, not raised.
    """
    completed = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return completed.returncode, completed.stdout, completed.stderr
# ---- Parsing
def properties_and_inherited_properties(clazz):
    """Return *clazz*'s properties preceded by those of its known ancestors.

    Superclasses are resolved through global_class_map; the root-most
    ancestor's properties come first, ending with clazz's own. The walk stops
    at the first superclass name not present in global_class_map.
    """
    if clazz.super_class_name not in global_class_map:
        return [*clazz.properties()]
    parent = global_class_map[clazz.super_class_name]
    return [*properties_and_inherited_properties(parent), *clazz.properties()]
def generate_swift_extensions_for_model(clazz):
print('\t', 'processing', clazz.__dict__)
if not clazz.should_generate_extensions():
return
has_sds_superclass = clazz.has_sds_superclass()
print('\t', '\t', 'clazz.name', clazz.name, type(clazz.name))
print('\t', '\t', 'clazz.super_class_name', clazz.super_class_name)
print('\t', '\t', 'filepath', clazz.filepath)
print('\t', '\t', 'table_superclass', clazz.table_superclass().name)
print('\t', '\t', 'has_sds_superclass', has_sds_superclass)
swift_filename = os.path.basename(clazz.filepath)
swift_filename = swift_filename[:swift_filename.find('.')] + '+SDS.swift'
swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename)
print('\t', '\t', 'swift_filepath', swift_filepath)
record_type = get_record_type(clazz)
print('\t', '\t', 'record_type', record_type)
# TODO: We'll need to import SignalServiceKit for non-SSK models.
swift_body = '''//
// Copyright (c) 2022 Open Whisper Systems. All rights reserved.
//
import Foundation
import GRDB
import SignalCoreKit
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
''' % ( sds_common.pretty_module_path(__file__), )
if not has_sds_superclass:
# If a property has a custom column source, we don't redundantly create a column for that column
base_properties = [property for property in clazz.properties() if not property.has_custom_column_source()]
# If a property has a custom column source, we don't redundantly create a column for that column
subclass_properties = [property for property in clazz.database_subclass_properties() if not property.has_custom_column_source()]
swift_body += '''
// MARK: - Record
'''
record_name = clazz.record_name()
swift_body += '''
public struct %s: SDSRecord {
public weak var delegate: SDSRecordDelegate?
public var tableMetadata: SDSTableMetadata {
%sSerializer.table
}
public static var databaseTableName: String {
%sSerializer.table.tableName
}
public var id: Int64?
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
public let recordType: SDSRecordType
public let uniqueId: String
''' % ( record_name, str(clazz.name), str(clazz.name), )
def write_record_property(property, force_optional=False):
column_name = property.swift_identifier()
# print 'property', property.swift_type_safe()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = '?' if is_optional else ''
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
column_name = custom_column_name
return ''' public let %s: %s%s
''' % ( str(column_name), record_field_type, optional_split, )
record_properties = clazz.sorted_record_properties()
# Declare the model properties in the record.
if len(record_properties) > 0:
swift_body += '\n // Properties \n'
for property in record_properties:
swift_body += write_record_property(property, force_optional=property.force_optional)
sds_properties = [
ParsedProperty({"name": "id", "is_optional": False, "objc_type": "NSInteger", "class_name": clazz.name}),
ParsedProperty({"name": "recordType", "is_optional": False, "objc_type": "NSUInteger", "class_name": clazz.name}),
ParsedProperty({"name": "uniqueId", "is_optional": False, "objc_type": "NSString *", "class_name": clazz.name})
]
# We use the pre-sorted collection record_properties so that
# we use the correct property order when generating:
#
# * CodingKeys
# * init(row: Row)
# * The table/column metadata.
persisted_properties = sds_properties + record_properties
swift_body += '''
public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable {
'''
for property in persisted_properties:
custom_column_name = custom_column_name_for_property(property)
was_property_renamed = was_property_renamed_for_property(property)
if custom_column_name is not None:
if was_property_renamed:
swift_body += ''' case %s
''' % ( custom_column_name, )
else:
swift_body += ''' case %s = "%s"
''' % ( custom_column_name, property.swift_identifier(), )
else:
swift_body += ''' case %s
''' % ( property.swift_identifier(), )
swift_body += ''' }
'''
swift_body += '''
public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String {
fullyQualified ? "\(databaseTableName).\(column.rawValue)" : column.rawValue
}
public func didInsert(with rowID: Int64, for column: String?) {
guard let delegate = delegate else {
owsFailDebug("Missing delegate.")
return
}
delegate.updateRowId(rowID)
}
}
''' % ( record_name, )
swift_body += '''
// MARK: - Row Initializer
public extension %s {
static var databaseSelection: [SQLSelectable] {
CodingKeys.allCases
}
init(row: Row) {''' % (record_name)
for index, property in enumerate(persisted_properties):
swift_body += '''
%s = row[%s]''' % (property.column_name(), index)
swift_body += '''
}
}
'''
swift_body += '''
// MARK: - StringInterpolation
public extension String.StringInterpolation {
mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column))
}
mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column, fullyQualified: true))
}
}
''' % { 'record_identifier': record_identifier(clazz.name), 'record_name': record_name }
swift_body += '''
// MARK: - Deserialization
// TODO: Rework metadata to not include, for example, columns, column indices.
extension %s {
// This method defines how to deserialize a model, given a
// database row. The recordType column is used to determine
// the corresponding model class.
class func fromRecord(_ record: %s) throws -> %s {
''' % ( str(clazz.name), record_name, str(clazz.name), )
swift_body += '''
guard let recordId = record.id else {
throw SDSError.invalidValue
}
switch record.recordType {
'''
deserialize_classes = all_descendents_of_class(clazz) + [clazz]
deserialize_classes.sort(key=lambda value: value.name)
for deserialize_class in deserialize_classes:
if should_ignore_class(deserialize_class):
continue
initializer_params = []
objc_initializer_params = []
objc_super_initializer_args = []
objc_initializer_assigns = []
deserialize_record_type = get_record_type_enum_name(deserialize_class.name)
swift_body += ''' case .%s:
''' % ( str(deserialize_record_type), )
swift_body += '''
let uniqueId: String = record.uniqueId
'''
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(deserialize_class)
has_local_properties = False
for property in deserialize_properties:
value_name = '%s' % property.name
if property.name not in ( 'uniqueId', ):
did_force_optional = (property.name not in base_property_names) and (not property.is_optional)
for statement in property.deserialize_record_invocation(value_name, did_force_optional):
# print 'statement', statement, type(statement)
swift_body += ' %s\n' % ( str(statement), )
initializer_params.append('%s: %s' % ( str(property.name), value_name, ) )
objc_initializer_type = str(property.objc_type_safe())
if objc_initializer_type.startswith('NSMutable'):
objc_initializer_type = 'NS' + objc_initializer_type[len('NSMutable'):]
if property.is_optional:
objc_initializer_type = 'nullable ' + objc_initializer_type
objc_initializer_params.append('%s:(%s)%s' % ( str(property.name), objc_initializer_type, str(property.name), ) )
is_superclass_property = property.class_name != deserialize_class.name
if is_superclass_property:
objc_super_initializer_args.append('%s:%s' % ( str(property.name), str(property.name), ) )
else:
has_local_properties = True
if str(property.objc_type_safe()).startswith('NSMutableArray'):
objc_initializer_assigns.append('_%s = %s ? [%s mutableCopy] : [NSMutableArray new];' % ( str(property.name), str(property.name), str(property.name), ) )
elif str(property.objc_type_safe()).startswith('NSMutableDictionary'):
objc_initializer_assigns.append('_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];' % ( str(property.name), str(property.name), str(property.name), ) )
else:
objc_initializer_assigns.append('_%s = %s;' % ( str(property.name), str(property.name), ) )
# --- Initializer Snippets
h_snippet = ''
h_snippet += '''
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
'''
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(0, len('- (instancetype)initWithUniqueId') - objc_initializer_param.index(':'))
h_snippet += (' ' * alignment) + objc_initializer_param + '\n'
h_snippet += 'NS_DESIGNATED_INITIALIZER NS_SWIFT_NAME(init(grdbId:%s:));\n' % ':'.join([str(property.name) for property in deserialize_properties])
h_snippet += '''
// clang-format on
'''
m_snippet = ''
m_snippet += '''
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
'''
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(0, len('- (instancetype)initWithUniqueId') - objc_initializer_param.index(':'))
m_snippet += (' ' * alignment) + objc_initializer_param + '\n'
if len(objc_super_initializer_args) == 1:
suffix = '];'
else:
suffix = ''
m_snippet += '''{
self = [super initWithGrdbId:grdbId
uniqueId:uniqueId%s
''' % (suffix)
for index, objc_super_initializer_arg in enumerate(objc_super_initializer_args[1:]):
alignment = max(0, len(' self = [super initWithUniqueId') - objc_super_initializer_arg.index(':'))
if index == len(objc_super_initializer_args) - 2:
suffix = '];'
else:
suffix = ''
m_snippet += (' ' * alignment) + objc_super_initializer_arg + suffix + '\n'
m_snippet += '''
if (!self) {
return self;
}
'''
for objc_initializer_assign in objc_initializer_assigns:
m_snippet += (' ' * 4) + objc_initializer_assign + '\n'
if deserialize_class.finalize_method_name is not None:
m_snippet += '''
[self %s];
''' % ( str(deserialize_class.finalize_method_name), )
m_snippet += '''
return self;
}
// clang-format on
'''
# Skip initializer generation for classes without any properties.
if not has_local_properties:
h_snippet = ''
m_snippet = ''
if deserialize_class.filepath.endswith('.m'):
m_filepath = deserialize_class.filepath
h_filepath = m_filepath[:-2] + '.h'
update_objc_snippet(h_filepath, h_snippet)
update_objc_snippet(m_filepath, m_snippet)
swift_body += '''
'''
# --- Invoke Initializer
initializer_invocation = ' return %s(' % str(deserialize_class.name)
swift_body += initializer_invocation
initializer_params = ['grdbId: recordId',] + initializer_params
swift_body += (',\n' + ' ' * len(initializer_invocation)).join(initializer_params)
swift_body += ')'
swift_body += '''
'''
# TODO: We could generate a comment with the Obj-C (or Swift) model initializer
# that this deserialization code expects.
swift_body += ''' default:
owsFailDebug("Unexpected record type: \(record.recordType)")
throw SDSError.invalidValue
'''
swift_body += ''' }
'''
swift_body += ''' }
'''
swift_body += '''}
'''
# TODO: Remove the serialization glue below.
if not has_sds_superclass:
swift_body += '''
// MARK: - SDSModel
extension %s: SDSModel {
public var serializer: SDSSerializer {
// Any subclass can be cast to it's superclass,
// so the order of this switch statement matters.
// We need to do a "depth first" search by type.
switch self {''' % str(clazz.name)
for subclass in reversed(all_descendents_of_class(clazz)):
if should_ignore_class(subclass):
continue
swift_body += '''
case let model as %s:
assert(type(of: model) == %s.self)
return %sSerializer(model: model)''' % ( str(subclass.name), str(subclass.name), str(subclass.name), )
swift_body += '''
default:
return %sSerializer(model: self)
}
}
public func asRecord() throws -> SDSRecord {
try serializer.asRecord()
}
public var sdsTableName: String {
%s.databaseTableName
}
public static var table: SDSTableMetadata {
%sSerializer.table
}
}
''' % ( str(clazz.name), record_name, str(clazz.name), )
if not has_sds_superclass:
swift_body += '''
// MARK: - DeepCopyable
extension %(class_name)s: DeepCopyable {
public func deepCopy() throws -> AnyObject {
// Any subclass can be cast to it's superclass,
// so the order of this switch statement matters.
// We need to do a "depth first" search by type.
guard let id = self.grdbId?.int64Value else {
throw OWSAssertionError("Model missing grdbId.")
}
''' % { "class_name": str(clazz.name) }
# switch self {''' % { "class_name": str(clazz.name) }
classes_to_copy = list(reversed(all_descendents_of_class(clazz))) + [clazz,]
for class_to_copy in classes_to_copy:
if should_ignore_class(class_to_copy):
continue
if class_to_copy == clazz:
swift_body += '''
do {
let modelToCopy = self
assert(type(of: modelToCopy) == %(class_name)s.self)
''' % { "class_name": str(class_to_copy.name) }
else:
swift_body += '''
if let modelToCopy = self as? %(class_name)s {
assert(type(of: modelToCopy) == %(class_name)s.self)
''' % { "class_name": str(class_to_copy.name) }
initializer_params = []
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(class_to_copy)
for property in deserialize_properties:
value_name = '%s' % property.name
did_force_optional = (property.name not in base_property_names) and (not property.is_optional)
for statement in property.deep_copy_record_invocation(value_name, did_force_optional):
swift_body += ' %s\n' % ( str(statement), )
initializer_params.append('%s: %s' % ( str(property.name), value_name, ) )
swift_body += '''
'''
# --- Invoke Initializer
initializer_invocation = ' return %s(' % str(class_to_copy.name)
swift_body += initializer_invocation
initializer_params = ['grdbId: id',] + initializer_params
swift_body += (',\n' + ' ' * len(initializer_invocation)).join(initializer_params)
swift_body += ')'
swift_body += '''
}
'''
swift_body += '''
}
}
'''
if not has_sds_superclass:
swift_body += '''
// MARK: - Table Metadata
extension %sSerializer {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
''' % str(clazz.name)
# Eventually we need a (persistent?) mechanism for guaranteeing
# consistency of column ordering, that is robust to schema
# changes, class hierarchy changes, etc.
column_property_names = []
def write_column_metadata(property, force_optional=False):
column_name = property.swift_identifier()
column_property_names.append(column_name)
is_optional = property.is_optional or force_optional
optional_split = ', isOptional: true' if is_optional else ''
is_unique = column_name == str('uniqueId')
is_unique_split = ', isUnique: true' if is_unique else ''
# print 'property', property.swift_type_safe()
database_column_type = property.database_column_type()
if property.name == 'id':
database_column_type = '.primaryKey'
# TODO: Use skipSelect.
return ''' static var %sColumn: SDSColumnMetadata { SDSColumnMetadata(columnName: "%s", columnType: %s%s%s) }
''' % ( str(column_name), str(column_name), database_column_type, optional_split, is_unique_split )
for property in sds_properties:
swift_body += write_column_metadata(property)
if len(record_properties) > 0:
swift_body += ' // Properties \n'
for property in record_properties:
swift_body += write_column_metadata(property, force_optional=property.force_optional)
database_table_name = 'model_%s' % str(clazz.name)
swift_body += '''
// TODO: We should decide on a naming convention for
// tables that store models.
public static var table: SDSTableMetadata {
SDSTableMetadata(collection: %s.collection(),
tableName: "%s",
columns: [
''' % ( str(clazz.name), database_table_name, )
for column_property_name in column_property_names:
swift_body += ''' %sColumn,
''' % ( str(column_property_name) )
swift_body += ''' ])
}
}
'''
# ---- Fetch ----
swift_body += '''
// MARK: - Save/Remove/Update
@objc
public extension %(class_name)s {
func anyInsert(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .insert, transaction: transaction)
}
// Avoid this method whenever feasible.
//
// If the record has previously been saved, this method does an overwriting
// update of the corresponding row, otherwise if it's a new record, this
// method inserts a new row.
//
// For performance, when possible, you should explicitly specify whether
// you are inserting or updating rather than calling this method.
func anyUpsert(transaction: SDSAnyWriteTransaction) {
let isInserting: Bool
if %(class_name)s.anyFetch(uniqueId: uniqueId, transaction: transaction) != nil {
isInserting = false
} else {
isInserting = true
}
sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction)
}
// This method is used by "updateWith..." methods.
//
// This model may be updated from many threads. We don't want to save
// our local copy (this instance) since it may be out of date. We also
// want to avoid re-saving a model that has been deleted. Therefore, we
// use "updateWith..." methods to:
//
// a) Update a property of this instance.
// b) If a copy of this model exists in the database, load an up-to-date copy,
// and update and save that copy.
// b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save
// this local instance.
//
// After "updateWith...":
//
// a) Any copy of this model in the database will have been updated.
// b) The local property on this instance will always have been updated.
// c) Other properties on this instance may be out of date.
//
// All mutable properties of this class have been made read-only to
// prevent accidentally modifying them directly.
//
// This isn't a perfect arrangement, but in practice this will prevent
// data loss and will resolve all known issues.
func anyUpdate(transaction: SDSAnyWriteTransaction, block: (%(class_name)s) -> Void) {
block(self)
guard let dbCopy = type(of: self).anyFetch(uniqueId: uniqueId,
transaction: transaction) else {
return
}
// Don't apply the block twice to the same instance.
// It's at least unnecessary and actually wrong for some blocks.
// e.g. `block: { $0 in $0.someField++ }`
if dbCopy !== self {
block(dbCopy)
}
dbCopy.sdsSave(saveMode: .update, transaction: transaction)
}
// This method is an alternative to `anyUpdate(transaction:block:)` methods.
//
// We should generally use `anyUpdate` to ensure we're not unintentionally
// clobbering other columns in the database when another concurrent update
// has occurred.
//
// There are cases when this doesn't make sense, e.g. when we know we've
// just loaded the model in the same transaction. In those cases it is
// safe and faster to do a "overwriting" update
func anyOverwritingUpdate(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .update, transaction: transaction)
}
func anyRemove(transaction: SDSAnyWriteTransaction) {
sdsRemove(transaction: transaction)
}
func anyReload(transaction: SDSAnyReadTransaction) {
anyReload(transaction: transaction, ignoreMissing: false)
}
func anyReload(transaction: SDSAnyReadTransaction, ignoreMissing: Bool) {
guard let latestVersion = type(of: self).anyFetch(uniqueId: uniqueId, transaction: transaction) else {
if !ignoreMissing {
owsFailDebug("`latest` was unexpectedly nil")
}
return
}
setValuesForKeys(latestVersion.dictionaryValue)
}
}
''' % { "class_name": str(clazz.name) }
# ---- Cursor ----
swift_body += '''
// MARK: - %sCursor
@objc
public class %sCursor: NSObject, SDSCursor {
private let transaction: GRDBReadTransaction
private let cursor: RecordCursor<%s>?
init(transaction: GRDBReadTransaction, cursor: RecordCursor<%s>?) {
self.transaction = transaction
self.cursor = cursor
}
public func next() throws -> %s? {
guard let cursor = cursor else {
return nil
}
guard let record = try cursor.next() else {
return nil
}''' % ( str(clazz.name), str(clazz.name), record_name, record_name, str(clazz.name), )
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += '''
let value = try %s.fromRecord(record)
%s(value, transaction: transaction.asAnyRead)
return value''' % ( str(clazz.name), cache_code, )
else:
swift_body += '''
return try %s.fromRecord(record)''' % ( str(clazz.name), )
swift_body += '''
}
public func all() throws -> [%s] {
var result = [%s]()
while true {
guard let model = try next() else {
break
}
result.append(model)
}
return result
}
}
''' % ( str(clazz.name), str(clazz.name), )
# ---- Fetch ----
swift_body += '''
// MARK: - Obj-C Fetch
// TODO: We may eventually want to define some combination of:
//
// * fetchCursor, fetchOne, fetchAll, etc. (ala GRDB)
// * Optional "where clause" parameters for filtering.
// * Async flavors with completions.
//
// TODO: I've defined flavors that take a read transaction.
// Or we might take a "connection" if we end up having that class.
@objc
public extension %(class_name)s {
class func grdbFetchCursor(transaction: GRDBReadTransaction) -> %(class_name)sCursor {
let database = transaction.database
do {
let cursor = try %(record_name)s.fetchCursor(database)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
} catch {
owsFailDebug("Read failed: \(error)")
return %(class_name)sCursor(transaction: transaction, cursor: nil)
}
}
''' % { "class_name": str(clazz.name), "record_name": record_name }
swift_body += '''
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String,
transaction: SDSAnyReadTransaction) -> %(class_name)s? {
assert(uniqueId.count > 0)
''' % { "class_name": str(clazz.name), "record_name": record_name, "record_identifier": record_identifier(clazz.name) }
cache_code = cache_get_code_for_class(clazz)
if cache_code is not None:
swift_body += '''
return anyFetch(uniqueId: uniqueId, transaction: transaction, ignoreCache: false)
}
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String,
transaction: SDSAnyReadTransaction,
ignoreCache: Bool) -> %(class_name)s? {
assert(uniqueId.count > 0)
if !ignoreCache,
let cachedCopy = %(cache_code)s {
return cachedCopy
}
''' % { "class_name": str(clazz.name), "cache_code": str(cache_code), }
swift_body += '''
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let sql = "SELECT * FROM \(%(record_name)s.databaseTableName) WHERE \(%(record_identifier)sColumn: .uniqueId) = ?"
return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: grdbTransaction)
}
}
''' % { "record_name": record_name, "record_identifier": record_identifier(clazz.name) }
swift_body += '''
// Traverses all records.
// Records are not visited in any particular order.
class func anyEnumerate(transaction: SDSAnyReadTransaction,
block: @escaping (%s, UnsafeMutablePointer<ObjCBool>) -> Void) {
anyEnumerate(transaction: transaction, batched: false, block: block)
}
// Traverses all records.
// Records are not visited in any particular order.
class func anyEnumerate(transaction: SDSAnyReadTransaction,
batched: Bool = false,
block: @escaping (%s, UnsafeMutablePointer<ObjCBool>) -> Void) {
let batchSize = batched ? Batching.kDefaultBatchSize : 0
anyEnumerate(transaction: transaction, batchSize: batchSize, block: block)
}
// Traverses all records.
// Records are not visited in any particular order.
//
// If batchSize > 0, the enumeration is performed in autoreleased batches.
class func anyEnumerate(transaction: SDSAnyReadTransaction,
batchSize: UInt,
block: @escaping (%s, UnsafeMutablePointer<ObjCBool>) -> Void) {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let cursor = %s.grdbFetchCursor(transaction: grdbTransaction)
Batching.loop(batchSize: batchSize,
loopBlock: { stop in
do {
guard let value = try cursor.next() else {
stop.pointee = true
return
}
block(value, stop)
} catch let error {
owsFailDebug("Couldn't fetch model: \(error)")
}
})
}
}
''' % ( ( str(clazz.name), ) * 4 )
swift_body += '''
// Traverses all records' unique ids.
// Records are not visited in any particular order.
class func anyEnumerateUniqueIds(transaction: SDSAnyReadTransaction,
block: @escaping (String, UnsafeMutablePointer<ObjCBool>) -> Void) {
anyEnumerateUniqueIds(transaction: transaction, batched: false, block: block)
}
// Traverses all records' unique ids.
// Records are not visited in any particular order.
class func anyEnumerateUniqueIds(transaction: SDSAnyReadTransaction,
batched: Bool = false,
block: @escaping (String, UnsafeMutablePointer<ObjCBool>) -> Void) {
let batchSize = batched ? Batching.kDefaultBatchSize : 0
anyEnumerateUniqueIds(transaction: transaction, batchSize: batchSize, block: block)
}
// Traverses all records' unique ids.
// Records are not visited in any particular order.
//
// If batchSize > 0, the enumeration is performed in autoreleased batches.
class func anyEnumerateUniqueIds(transaction: SDSAnyReadTransaction,
batchSize: UInt,
block: @escaping (String, UnsafeMutablePointer<ObjCBool>) -> Void) {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
grdbEnumerateUniqueIds(transaction: grdbTransaction,
sql: """
SELECT \(%sColumn: .uniqueId)
FROM \(%s.databaseTableName)
""",
batchSize: batchSize,
block: block)
}
}
''' % ( record_identifier(clazz.name), record_name, )
swift_body += '''
// Does not order the results.
class func anyFetchAll(transaction: SDSAnyReadTransaction) -> [%s] {
var result = [%s]()
anyEnumerate(transaction: transaction) { (model, _) in
result.append(model)
}
return result
}
// Does not order the results.
class func anyAllUniqueIds(transaction: SDSAnyReadTransaction) -> [String] {
var result = [String]()
anyEnumerateUniqueIds(transaction: transaction) { (uniqueId, _) in
result.append(uniqueId)
}
return result
}
''' % ( ( str(clazz.name), ) * 2 )
# ---- Count ----
swift_body += '''
class func anyCount(transaction: SDSAnyReadTransaction) -> UInt {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
return %s.ows_fetchCount(grdbTransaction.database)
}
}
''' % ( record_name, )
# ---- Remove All ----
swift_body += '''
// WARNING: Do not use this method for any models which do cleanup
// in their anyWillRemove(), anyDidRemove() methods.
class func anyRemoveAllWithoutInstantation(transaction: SDSAnyWriteTransaction) {
switch transaction.writeTransaction {
case .grdbWrite(let grdbTransaction):
do {
try %s.deleteAll(grdbTransaction.database)
} catch {
owsFailDebug("deleteAll() failed: \(error)")
}
}
if ftsIndexMode != .never {
FullTextSearchFinder.allModelsWereRemoved(collection: collection(), transaction: transaction)
}
}
class func anyRemoveAllWithInstantation(transaction: SDSAnyWriteTransaction) {
// To avoid mutationDuringEnumerationException, we need
// to remove the instances outside the enumeration.
let uniqueIds = anyAllUniqueIds(transaction: transaction)
var index: Int = 0
Batching.loop(batchSize: Batching.kDefaultBatchSize,
loopBlock: { stop in
guard index < uniqueIds.count else {
stop.pointee = true
return
}
let uniqueId = uniqueIds[index]
index = index + 1
guard let instance = anyFetch(uniqueId: uniqueId, transaction: transaction) else {
owsFailDebug("Missing instance.")
return
}
instance.anyRemove(transaction: transaction)
})
if ftsIndexMode != .never {
FullTextSearchFinder.allModelsWereRemoved(collection: collection(), transaction: transaction)
}
}
class func anyExists(uniqueId: String,
transaction: SDSAnyReadTransaction) -> Bool {
assert(uniqueId.count > 0)
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let sql = "SELECT EXISTS ( SELECT 1 FROM \(%s.databaseTableName) WHERE \(%sColumn: .uniqueId) = ? )"
let arguments: StatementArguments = [uniqueId]
return try! Bool.fetchOne(grdbTransaction.database, sql: sql, arguments: arguments) ?? false
}
}
}
''' % ( record_name, record_name, record_identifier(clazz.name), )
# ---- Fetch ----
swift_body += '''
// MARK: - Swift Fetch
public extension %(class_name)s {
class func grdbFetchCursor(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: GRDBReadTransaction) -> %(class_name)sCursor {
do {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
let cursor = try %(record_name)s.fetchCursor(transaction.database, sqlRequest)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
} catch {
Logger.verbose("sql: \(sql)")
owsFailDebug("Read failed: \(error)")
return %(class_name)sCursor(transaction: transaction, cursor: nil)
}
}
''' % { "class_name": str(clazz.name), "record_name": record_name }
string_interpolation_name = remove_prefix_from_class_name(clazz.name)
swift_body += '''
class func grdbFetchOne(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: GRDBReadTransaction) -> %s? {
assert(sql.count > 0)
do {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
guard let record = try %s.fetchOne(transaction.database, sqlRequest) else {
return nil
}
''' % ( str(clazz.name), record_name, )
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += '''
let value = try %s.fromRecord(record)
%s(value, transaction: transaction.asAnyRead)
return value''' % ( str(clazz.name), cache_code, )
else:
swift_body += '''
return try %s.fromRecord(record)''' % ( str(clazz.name), )
swift_body += '''
} catch {
owsFailDebug("error: \(error)")
return nil
}
}
}
'''
# ---- Typed Convenience Methods ----
if has_sds_superclass:
swift_body += '''
// MARK: - Typed Convenience Methods
@objc
public extension %s {
// NOTE: This method will fail if the object has unexpected type.
class func anyFetch%s(uniqueId: String,
transaction: SDSAnyReadTransaction) -> %s? {
assert(uniqueId.count > 0)
guard let object = anyFetch(uniqueId: uniqueId,
transaction: transaction) else {
return nil
}
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \(type(of: object))")
return nil
}
return instance
}
// NOTE: This method will fail if the object has unexpected type.
func anyUpdate%s(transaction: SDSAnyWriteTransaction, block: (%s) -> Void) {
anyUpdate(transaction: transaction) { (object) in
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \(type(of: object))")
return
}
block(instance)
}
}
}
''' % ( str(clazz.name), str(remove_prefix_from_class_name(clazz.name)), str(clazz.name), str(clazz.name), str(remove_prefix_from_class_name(clazz.name)), str(clazz.name), str(clazz.name), )
# ---- SDSModel ----
table_superclass = clazz.table_superclass()
table_class_name = str(table_superclass.name)
has_serializable_superclass = table_superclass.name != clazz.name
override_keyword = ''
swift_body += '''
// MARK: - SDSSerializer
// The SDSSerializer protocol specifies how to insert and update the
// row that corresponds to this model.
class %sSerializer: SDSSerializer {
private let model: %s
public required init(model: %s) {
self.model = model
}
''' % ( str(clazz.name), str(clazz.name), str(clazz.name), )
# --- To Record
root_class = clazz.table_superclass()
root_record_name = remove_prefix_from_class_name(root_class.name) + 'Record'
record_id_source = "model.grdbId?.int64Value"
if root_class.record_id_source() is not None:
record_id_source = "model.%(source)s > 0 ? Int64(model.%(source)s) : %(default_source)s" % {
"source": root_class.record_id_source(),
"default_source": record_id_source,
}
swift_body += '''
// MARK: - Record
func asRecord() throws -> SDSRecord {
let id: Int64? = %(record_id_source)s
let recordType: SDSRecordType = .%(record_type)s
let uniqueId: String = model.uniqueId
''' % { "record_type": get_record_type_enum_name(clazz.name), "record_id_source": record_id_source }
initializer_args = ['id', 'recordType', 'uniqueId', ]
initializer_value_names = []
for property in properties_and_inherited_properties(clazz):
initializer_value_names.append(property.name)
# print 'initializer_value_names', initializer_value_names
def write_record_property(property, force_optional=False):
optional_value = ''
if property.swift_identifier() in initializer_value_names:
did_force_optional = property.force_optional
model_accessor = accessor_name_for_property(property)
value_expr = property.serialize_record_invocation('model.%s' % ( model_accessor, ), did_force_optional)
optional_value = ' = %s' % ( value_expr, )
else:
optional_value = ' = nil'
# print 'property', property.swift_type_safe()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = '?' if is_optional else ''
initializer_args.append(property.column_name())
return ''' let %s: %s%s%s
''' % ( str(property.column_name()), record_field_type, optional_split, optional_value, )
root_record_properties = root_class.sorted_record_properties()
if len(root_record_properties) > 0:
swift_body += '\n // Properties \n'
for property in root_record_properties:
swift_body += write_record_property(property, force_optional=property.force_optional)
initializer_args = ['%s: %s' % ( arg, arg, ) for arg in initializer_args]
swift_body += '''
return %s(delegate: model, %s)
}
''' % ( root_record_name, ', '.join(initializer_args), )
swift_body += '''}
'''
if not has_sds_superclass:
swift_body += '''
// MARK: - Deep Copy
#if TESTABLE_BUILD
@objc
public extension %(model_name)s {
// We're not using this method at the moment,
// but we might use it for validation of
// other deep copy methods.
func deepCopyUsingRecord() throws -> %(model_name)s {
guard let record = try asRecord() as? %(record_name)s else {
throw OWSAssertionError("Could not convert to record.")
}
return try %(model_name)s.fromRecord(record)
}
}
#endif
''' % { "model_name": str(clazz.name), "record_name": clazz.record_name(), }
# print 'swift_body', swift_body
print('Writing:', swift_filepath)
swift_body = sds_common.clean_up_generated_swift(swift_body)
# Add some random whitespace to trigger the auto-formatter.
swift_body = swift_body + (' ' * random.randint(1, 100))
sds_common.write_text_file_if_changed(swift_filepath, swift_body)
def process_class_map(class_map):
    """Generate the Swift SDS extensions for every parsed class in *class_map*.

    Only the mapping's values are used; the keys are ignored. Each value is
    passed, in iteration order, to ``generate_swift_extensions_for_model``.

    Args:
        class_map: mapping whose values are parsed model classes
            (presumably keyed by class name — confirm against callers).
    """
    print('processing', class_map)
    parsed_classes = class_map.values()
    for parsed_class in parsed_classes:
        generate_swift_extensions_for_model(parsed_class)
# ---- Record Type Map

# Module-level registry: model class name -> integer record type.
# update_record_type_map() fills it from the persisted JSON file and assigns
# the next free value to each newly discovered class. Keys that begin with
# '#' (e.g. '#comment') are metadata entries, not class names.
record_type_map = dict()
# It's critical that our "record type" values stay stable, even as model classes are added, removed, or renamed,
# because they are written into each row's recordType column. Therefore we persist the mapping of known classes
# in a JSON file that is under source control.
def update_record_type_map(record_type_swift_path, record_type_json_path):
    """Load, extend, and persist the record type map, then regenerate its Swift enum.

    Record type values must stay stable across runs, so the mapping lives in
    a JSON file under source control:

    1. Existing entries are loaded from ``record_type_json_path`` (if the file
       exists) into the module-level ``record_type_map``.
    2. Every class in ``global_class_map`` that has no entry yet and whose
       ``should_generate_extensions()`` is true gets the next unused integer.
    3. The map (plus a '#comment' notice) is saved back to the JSON file, and
       the ``SDSRecordType`` Swift enum is saved to ``record_type_swift_path``
       with one case per non-comment key, ordered by record type value.
       Both files are written via sds_common.write_text_file_if_changed.

    Keys starting with '#' are comments, not class names.

    :param record_type_swift_path: path of the generated Swift enum file.
    :param record_type_json_path: path of the JSON record type map.
    """
    print('update_record_type_map')

    record_type_map_filepath = record_type_json_path
    if os.path.exists(record_type_map_filepath):
        with open(record_type_map_filepath, 'rt') as f:
            json_string = f.read()
        json_data = json.loads(json_string)
        record_type_map.update(json_data)

    max_record_type = 0
    for class_name in record_type_map:
        if class_name.startswith('#'):
            continue
        record_type = record_type_map[class_name]
        max_record_type = max(max_record_type, record_type)

    # Hand out new record types in class-name order. global_class_map's
    # insertion order follows the file-system walk that built it, which is
    # not guaranteed to be stable, so iterating it directly could assign
    # different values to the same set of new classes on different machines.
    for clazz in sorted(global_class_map.values(), key=lambda value: value.name):
        if clazz.name in record_type_map:
            continue
        if not clazz.should_generate_extensions():
            continue
        max_record_type = int(max_record_type) + 1
        record_type = max_record_type
        record_type_map[clazz.name] = record_type

    record_type_map['#comment'] = 'NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`.' % ( sds_common.pretty_module_path(__file__), )

    json_string = json.dumps(record_type_map, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(record_type_map_filepath, json_string)

    # TODO: We'll need to import SignalServiceKit for non-SSK classes.

    swift_body = '''//
// Copyright © 2022 Signal. All rights reserved.
//
import Foundation
import GRDB
import SignalCoreKit
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
@objc
public enum SDSRecordType: UInt, CaseIterable {
''' % ( sds_common.pretty_module_path(__file__), )

    # (enum case name, record type) for every non-comment key, emitted in
    # ascending record type order.
    record_type_pairs = []
    for key in record_type_map.keys():
        if key.startswith('#'):
            # Ignore comments
            continue
        enum_name = get_record_type_enum_name(key)
        record_type_pairs.append((str(enum_name), record_type_map[key]))

    record_type_pairs.sort(key=lambda value: value[1])
    for (enum_name, record_type_id) in record_type_pairs:
        swift_body += ''' case %s = %s
''' % ( enum_name, str(record_type_id), )

    swift_body += '''}
'''

    swift_body = sds_common.clean_up_generated_swift(swift_body)

    sds_common.write_text_file_if_changed(record_type_swift_path, swift_body)
def get_record_type(clazz):
    """Return the integer record type registered for ``clazz``.

    Looks up ``clazz.name`` in the module-level record_type_map; raises
    KeyError if the class has no entry.
    """
    class_name = clazz.name
    return record_type_map[class_name]
def remove_prefix_from_class_name(class_name):
    """Strip one leading Objective-C class prefix from ``class_name``.

    The prefixes 'TS', 'OWS' and 'SSK' are tried in that order and only the
    first match is removed. Names without any of these prefixes are returned
    unchanged.
    """
    for prefix in ('TS', 'OWS', 'SSK'):
        if class_name.startswith(prefix):
            return class_name[len(prefix):]
    return class_name
def get_record_type_enum_name(class_name):
    """Return the SDSRecordType Swift enum case name for ``class_name``.

    The class prefix is removed first. If the remainder starts with a numeric
    character, it is prefixed with '_' before conversion by
    to_swift_identifier_name().
    """
    stripped = remove_prefix_from_class_name(class_name)
    leading = '_' if stripped[0].isnumeric() else ''
    return to_swift_identifier_name(leading + stripped)
def record_identifier(class_name):
    """Return the Swift identifier for ``class_name`` with its class prefix removed."""
    unprefixed_name = remove_prefix_from_class_name(class_name)
    identifier = to_swift_identifier_name(unprefixed_name)
    return identifier
# ---- Column Ordering

# NOTE(review): no use of these two globals is visible around this section;
# confirm they are still referenced elsewhere in the file before relying on
# them, or remove them.
column_ordering_map = {}
has_loaded_column_ordering_map = False
# ---- Parsing

# Maps enum name -> underlying Objective-C type name (e.g. 'NSInteger').
# Accumulated from the "enums" section of every parsed SDS JSON file.
enum_type_map = {}
def objc_type_for_enum(enum_name):
    """Return the Objective-C underlying type recorded for ``enum_name``.

    If the enum is unknown, the whole enum_type_map is printed for debugging
    and fail() is called.
    """
    is_known = enum_name in enum_type_map
    if not is_known:
        print('enum_type_map', enum_type_map)
        fail('Enum has unknown type:', enum_name)
    return enum_type_map[enum_name]
# Underlying Objective-C enum type -> Swift type used in generated code.
OBJC_ENUM_TYPE_TO_SWIFT_TYPE = {
    'NSInteger': 'Int',
    'NSUInteger': 'UInt',
    'int32_t': 'Int32',
    'unsigned long long': 'UInt64',
    'unsigned long': 'UInt64',
    # NOTE(review): Swift imports `unsigned int` as UInt32, not UInt; confirm
    # the generated conversions expect UInt before changing this.
    'unsigned int': 'UInt',
}


def swift_type_for_enum(enum_name):
    """Return the Swift type corresponding to the underlying type of ``enum_name``.

    The Objective-C underlying type comes from objc_type_for_enum() and is
    translated through OBJC_ENUM_TYPE_TO_SWIFT_TYPE. Unsupported underlying
    types cause fail() to be called.

    'unsigned long long' maps to 'UInt64'. It previously returned the C
    typedef 'uint64_t', and a second, unreachable branch for the same type
    returned 'UInt64'.
    """
    objc_type = objc_type_for_enum(enum_name)
    swift_type = OBJC_ENUM_TYPE_TO_SWIFT_TYPE.get(objc_type)
    if swift_type is None:
        fail('Unknown objc type:', objc_type)
    return swift_type
def parse_sds_json(file_path):
    """Parse one SDS intermediary JSON file.

    Builds a ParsedClass for every entry in the file's "classes" list, then
    merges the file's "enums" dict into the module-level enum_type_map.

    :returns: dict mapping class name -> ParsedClass for this file.
    """
    with open(file_path, 'rt') as json_file:
        json_data = json.load(json_file)

    parsed_classes = {}
    for class_dict in json_data['classes']:
        parsed = ParsedClass(class_dict)
        parsed_classes[parsed.name] = parsed

    enum_type_map.update(json_data['enums'])
    return parsed_classes
def try_to_parse_file(file_path):
    """Parse ``file_path`` if it is an SDS intermediary JSON file.

    Only the file name is checked. Files whose name ends with
    sds_common.SDS_JSON_FILE_EXTENSION are passed to parse_sds_json();
    every other file is ignored.

    :returns: dict of class name -> ParsedClass, or an empty dict for
        ignored files.
    """
    filename = os.path.basename(file_path)
    if not filename.endswith(sds_common.SDS_JSON_FILE_EXTENSION):
        return {}
    print('\t', 'found', file_path)
    return parse_sds_json(file_path)
def find_sds_intermediary_files_in_path(path):
    """Collect parsed classes from every SDS intermediary file under ``path``.

    ``path`` may be a single file or a directory that is walked recursively.
    Files that are not SDS intermediary files contribute nothing (see
    try_to_parse_file).

    Directories and files are visited in sorted order. This makes the
    returned dict's insertion order, and any output that iterates it,
    independent of the file system's listing order. If two files define the
    same class name, the entry from the file visited later wins.

    :returns: dict mapping class name -> ParsedClass.
    """
    print('find_sds_intermediary_files_in_path', path)
    class_map = {}
    if os.path.isfile(path):
        class_map.update(try_to_parse_file(path))
    else:
        for rootdir, dirnames, filenames in os.walk(path):
            # Sorting dirnames in place makes os.walk descend in sorted order.
            dirnames.sort()
            for filename in sorted(filenames):
                file_path = os.path.abspath(os.path.join(rootdir, filename))
                class_map.update(try_to_parse_file(file_path))
    return class_map
def update_subclass_map():
    """Index every class in global_class_map under its superclass name.

    Appends each class to global_subclass_map[super_class_name]. Classes with
    no superclass are skipped. Calling this more than once appends the same
    classes again.
    """
    for clazz in global_class_map.values():
        parent_name = clazz.super_class_name
        if parent_name is None:
            continue
        global_subclass_map.setdefault(parent_name, []).append(clazz)
def all_descendents_of_class(clazz):
    """Return every transitive subclass of ``clazz`` in depth-first pre-order.

    Direct subclasses are visited in name order, and each one is immediately
    followed by its own descendants. Side effect: the subclass lists stored
    in global_subclass_map are sorted in place.
    """
    children = global_subclass_map.get(clazz.name, [])
    children.sort(key=lambda child: child.name)
    descendents = []
    for child in children:
        descendents += [child] + all_descendents_of_class(child)
    return descendents
def is_swift_class_name(swift_type):
    """Return True if ``swift_type`` names a class present in global_class_map."""
    known_class = global_class_map.get(swift_type)
    return known_class is not None
# ---- Config JSON

# Code generation configuration, loaded by parse_config_json() from the file
# passed via --config-json-path. It stays empty until that function runs.
configuration_json = {}
def parse_config_json(config_json_path):
    """Load the code generation config file into the module-level configuration_json.

    Any previously loaded configuration is replaced, not merged.
    """
    global configuration_json
    print('config_json_path', config_json_path)
    with open(config_json_path, 'rt') as config_file:
        configuration_json = json.load(config_file)
# We often use nullable NSNumber * for optional numerics (bool, int, int64, double, etc.).
# There's no way to infer which type is boxed in an NSNumber.
# Therefore, we need to specify that in the configuration JSON.
def swift_type_for_nsnumber(property):
    """Return the configured Swift type for an NSNumber-typed ``property``.

    Looks up '<class_name>.<name>' in the config's 'nsnumber_types' dict. If
    the dict or the entry is missing, prints a hint naming the config file
    and calls fail().
    """
    nsnumber_types = configuration_json.get('nsnumber_types')
    if nsnumber_types is None:
        print('Suggestion: update: %s' % (global_args.config_json_path,))
        fail('Configuration JSON is missing mapping for properties of type NSNumber.')

    key = '.'.join((property.class_name, property.name))
    swift_type = nsnumber_types.get(key)
    if swift_type is None:
        print('Suggestion: update: %s' % (global_args.config_json_path,))
        fail('Configuration JSON is missing mapping for properties of type NSNumber:', key)
    return swift_type
# Some properties shouldn't get serialized.
# For now, there's just one: TSGroupModel.groupImage which is a UIImage.
# We might end up extending the serialization to handle images.
# Or we might store these as Data/NSData/blob.
# TODO:
def should_ignore_property(property):
    """Return True if '<class_name>.<name>' is listed in the config's 'properties_to_ignore'.

    Calls fail() if the config has no 'properties_to_ignore' entry.
    """
    ignored = configuration_json.get('properties_to_ignore')
    if ignored is None:
        fail('Configuration JSON is missing list of properties to ignore during serialization.')
    return '.'.join((property.class_name, property.name)) in ignored
def custom_property_column_source(property):
    """Return the config's 'custom_property_column_sources' entry for ``property``, or None.

    Entries are keyed by '<class_name>.<name>'. Calls fail() if the config
    has no such dict.
    """
    column_sources = configuration_json.get('custom_property_column_sources')
    if column_sources is None:
        fail('Configuration JSON is missing dict of custom_property_column_sources.')
    return column_sources.get('.'.join((property.class_name, property.name)))
def cache_get_code_for_class(clazz):
    """Return the config's 'class_cache_get_code' entry for ``clazz.name``, or None.

    Calls fail() if the config has no 'class_cache_get_code' dict.
    """
    get_code_by_class = configuration_json.get('class_cache_get_code')
    if get_code_by_class is None:
        fail('Configuration JSON is missing dict of class_cache_get_code.')
    return get_code_by_class.get(clazz.name)
def cache_set_code_for_class(clazz):
    """Return the config's 'class_cache_set_code' entry for ``clazz.name``, or None.

    Calls fail() if the config has no 'class_cache_set_code' dict.
    """
    set_code_by_class = configuration_json.get('class_cache_set_code')
    if set_code_by_class is None:
        fail('Configuration JSON is missing dict of class_cache_set_code.')
    return set_code_by_class.get(clazz.name)
def should_ignore_class(clazz):
    """Return True if ``clazz`` or a known ancestor is excluded from serialization.

    A class is excluded when its name appears in the config's
    'class_to_skip_serialization' list. The superclass chain is followed
    only while each superclass is present in global_class_map. Calls fail()
    if the config list is missing.
    """
    skipped_classes = configuration_json.get('class_to_skip_serialization')
    if skipped_classes is None:
        fail('Configuration JSON is missing list of classes to ignore during serialization.')
    if clazz.name in skipped_classes:
        return True

    parent_name = clazz.super_class_name
    if parent_name is None or parent_name not in global_class_map:
        return False
    return should_ignore_class(global_class_map[parent_name])
def accessor_name_for_property(property):
    """Return the accessor name for ``property``.

    Uses the config's 'custom_accessors' entry for '<class_name>.<name>' when
    present, otherwise the property's own name. Calls fail() if the config
    has no 'custom_accessors' entry.
    """
    accessors = configuration_json.get('custom_accessors')
    if accessors is None:
        fail('Configuration JSON is missing list of custom property accessors.')
    lookup_key = '.'.join((property.class_name, property.name))
    return accessors.get(lookup_key, property.name)
# include_renamed_columns
def custom_column_name_for_property(property):
    """Return the config's 'custom_column_names' entry for ``property``, or None.

    Entries are keyed by '<class_name>.<name>'. Calls fail() if the config
    has no 'custom_column_names' entry.
    """
    column_names = configuration_json.get('custom_column_names')
    if column_names is None:
        fail('Configuration JSON is missing list of custom column names.')
    lookup_key = '.'.join((property.class_name, property.name))
    return column_names.get(lookup_key)
def was_property_renamed_for_property(property):
    """Return True if the config's 'renamed_column_names' has a non-None entry for ``property``.

    Entries are keyed by '<class_name>.<name>'. Calls fail() if the config
    has no 'renamed_column_names' entry.
    """
    renamed = configuration_json.get('renamed_column_names')
    if renamed is None:
        fail('Configuration JSON is missing list of renamed column names.')
    lookup_key = '.'.join((property.class_name, property.name))
    return renamed.get(lookup_key) is not None
# ---- Property Order JSON

# Cached "property order" values keyed by '<record name>.<property name>',
# plus a '#comment' notice. Loaded by parse_property_order_json() and written
# back by update_property_order_json().
property_order_json = {}
def parse_property_order_json(property_order_json_path):
    """Load the property-order cache file into the module-level property_order_json.

    Any previously loaded contents are replaced, not merged.
    """
    global property_order_json
    print('property_order_json_path', property_order_json_path)
    with open(property_order_json_path, 'rt') as order_file:
        property_order_json = json.load(order_file)
# It's critical that our "property order" is consistent, even if we add columns.
# Therefore we persist the "property order" for all known properties in a JSON file that is under source control.
def update_property_order_json(property_order_json_path):
    """Save property_order_json, with a generated-file notice, to ``property_order_json_path``.

    Sets (or overwrites) the '#comment' key of the module-level dict, then
    serializes it with sorted keys and 4-space indentation via
    sds_common.write_text_file_if_changed.
    """
    notice = 'NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`.' % (sds_common.pretty_module_path(__file__),)
    property_order_json['#comment'] = notice
    serialized = json.dumps(property_order_json, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(property_order_json_path, serialized)
def property_order_key(property, record_name):
    """Return the property-order cache key '<record_name>.<property.name>'."""
    return '.'.join((record_name, property.name))
def property_order_for_property(property, record_name):
    """Return the cached order value for ``property`` in ``record_name``, or None if absent."""
    return property_order_json.get(property_order_key(property, record_name))
def set_property_order_for_property(property, record_name, value):
    """Store ``value`` as the order of ``property`` in ``record_name``.

    This only updates the in-memory property_order_json;
    update_property_order_json() persists it.
    """
    order_key = property_order_key(property, record_name)
    property_order_json[order_key] = value
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate Swift extensions.')
    parser.add_argument('--src-path', required=True, help='file or directory containing the SDS intermediary files to generate Swift extensions for.')
    parser.add_argument('--search-path', required=True, help='file or directory searched for SDS intermediary files to build the global class hierarchy.')
    parser.add_argument('--record-type-swift-path', required=True, help='path of the record type enum swift file.')
    parser.add_argument('--record-type-json-path', required=True, help='path of the record type map json file.')
    parser.add_argument('--config-json-path', required=True, help='path of the json file with code generation config info.')
    parser.add_argument('--property-order-json-path', required=True, help='path of the json file with property ordering cache.')
    args = parser.parse_args()

    # Kept at module level so helpers (e.g. swift_type_for_nsnumber) can
    # mention the config path in their error hints.
    global_args = args

    src_path = os.path.abspath(args.src_path)
    search_path = os.path.abspath(args.search_path)
    record_type_swift_path = os.path.abspath(args.record_type_swift_path)
    record_type_json_path = os.path.abspath(args.record_type_json_path)
    config_json_path = os.path.abspath(args.config_json_path)
    property_order_json_path = os.path.abspath(args.property_order_json_path)

    # We control the code generation process using a JSON config file.
    print()
    print('Parsing Config')
    parse_config_json(config_json_path)

    print()
    print('Parsing Property Order')
    parse_property_order_json(property_order_json_path)

    # The code generation needs to understand the class hierarchy so that
    # it can:
    #
    # * Define table schemas that include the superset of properties in
    #   the model class hierarchies.
    # * Generate deserialization methods that handle all subclasses.
    # * etc.
    print()
    print('Parsing Global Class Map')
    global_class_map.update(find_sds_intermediary_files_in_path(search_path))
    print('global_class_map', global_class_map)
    update_subclass_map()

    print()
    print('Parsing Record Type Map')
    update_record_type_map(record_type_swift_path, record_type_json_path)

    print()
    print('Processing')
    process_class_map(find_sds_intermediary_files_in_path(src_path))

    # Persist updated property order
    update_property_order_json(property_order_json_path)