Signal-iOS/Scripts/sds_codegen/sds_generate.py
2019-07-16 14:00:13 -03:00

1894 lines
70 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import subprocess
import datetime
import argparse
import commands
import re
import json
import sds_common
from sds_common import fail
import random
# TODO: We should probably generate a class that knows how to set up
# the database. It would:
#
# * Create all tables (or apply database schema).
# * Register renamed classes.
# [NSKeyedUnarchiver setClass:[OWSUserProfile class] forClassName:[OWSUserProfile collection]];
# [NSKeyedUnarchiver setClass:[OWSDatabaseMigration class] forClassName:[OWSDatabaseMigration collection]];
# We consider any subclass of TSYapDatabaseObject (or of BaseModel) to be a
# "serializable model".
#
# We treat direct subclasses of either base class as "roots" of the model class
# hierarchy. Only root models do deserialization.
OLD_BASE_MODEL_CLASS_NAME = 'TSYapDatabaseObject'
NEW_BASE_MODEL_CLASS_NAME = 'BaseModel'
# Delimits the generated region inside hand-written Obj-C .h/.m files; the text
# between the first and last occurrence of this marker is replaced on each run.
CODE_GEN_SNIPPET_MARKER_OBJC = '// --- CODE GENERATION MARKER'
# GRDB seems to encode non-primitives using JSON, and then chokes when it
# decodes that JSON because it is a JSON "fragment". Either this is a bug in
# GRDB or we're using GRDB incorrectly. Until we resolve this issue, we
# encode/decode non-primitives ourselves and only use Codable for primitives.
ONLY_USE_CODABLE_FOR_PRIMITIVES = True
def update_generated_snippet(file_path, marker, snippet):
    """Replace the region between the first and last ``marker`` in a file.

    Everything between the first occurrence of ``marker`` and the last
    occurrence is replaced by ``snippet``. The whitespace around the markers
    is normalized to a single blank line. The result is handed to
    ``sds_common.write_text_file_if_changed``.

    Calls ``fail`` if the file does not exist, or if it does not contain two
    distinct occurrences of ``marker``.
    """
    if not os.path.exists(file_path):
        fail('Missing file:', file_path)

    with open(file_path, 'rt') as handle:
        original_text = handle.read()

    first_marker = original_text.find(marker)
    last_marker = original_text.rfind(marker)
    # With zero markers both indices are -1; with one marker they are equal.
    if first_marker < 0 or last_marker < 0 or first_marker >= last_marker:
        fail('Could not find markers:', file_path)

    before = original_text[:first_marker].strip()
    after = original_text[last_marker + len(marker):].lstrip()
    updated_text = '\n\n'.join([before, marker, snippet, marker, after])
    sds_common.write_text_file_if_changed(file_path, updated_text)
def update_objc_snippet(file_path, snippet):
    """Write a cleaned-up Obj-C ``snippet`` into the generated region of ``file_path``.

    The snippet is normalized with ``sds_common.clean_up_generated_objc``.
    If nothing remains after cleanup, the file is left untouched. Otherwise a
    "generated by" header is prepended and the result is written between the
    Obj-C code generation markers.
    """
    cleaned = sds_common.clean_up_generated_objc(snippet).strip()
    if not cleaned:
        return
    header = ('// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`.'
              % (sds_common.pretty_module_path(__file__),))
    update_generated_snippet(file_path, CODE_GEN_SNIPPET_MARKER_OBJC, header + '\n\n' + cleaned)
# ----
# Parsed model classes (ParsedClass instances) keyed by Obj-C class name.
global_class_map = {}
# NOTE(review): presumably maps a class name to its subclasses; populated outside this view — confirm.
global_subclass_map = {}
# NOTE(review): presumably holds the parsed command-line arguments once set — confirm.
global_args = None
# ----
def to_swift_identifier_name(identifier_name):
    """Return ``identifier_name`` with its first character lower-cased.

    For example, ``'UniqueId'`` becomes ``'uniqueId'``. An empty name is
    returned unchanged instead of raising ``IndexError``.
    """
    # Slicing (rather than indexing) keeps the empty string safe.
    return identifier_name[:1].lower() + identifier_name[1:]
class ParsedClass:
    """A model class as described by the JSON emitted by the SDS parser.

    Attributes:
        name: Obj-C class name.
        super_class_name: Name of the direct superclass (may be None).
        filepath: Source file path, resolved with
            ``sds_common.sds_from_relative_path``.
        finalize_method_name: Optional method name, used by the generator
            for a ``[self <name>]`` call in the generated Obj-C initializer.
        property_map: Property name -> ParsedProperty, excluding properties
            for which ``should_ignore_property()`` is true.
    """

    def __init__(self, json_dict):
        self.name = json_dict.get('name')
        self.super_class_name = json_dict.get('super_class_name')
        self.filepath = sds_common.sds_from_relative_path(json_dict.get('filepath'))
        self.finalize_method_name = json_dict.get('finalize_method_name')
        self.property_map = {}
        # A class without a 'properties' entry simply has no properties,
        # rather than crashing while iterating over None.
        for property_dict in json_dict.get('properties') or []:
            parsed_property = ParsedProperty(property_dict)
            parsed_property.class_name = self.name
            # TODO: We should handle all properties?
            if parsed_property.should_ignore_property():
                print('Ignoring property: %s' % (parsed_property.name,))
                continue
            self.property_map[parsed_property.name] = parsed_property

    def properties(self):
        """Return this class's own (non-ignored) properties, sorted by name."""
        return [self.property_map[name] for name in sorted(self.property_map)]

    @staticmethod
    def _print_property_conflict(conflicting_property, duplicate_property):
        """Print both declarations of a property whose subclass declarations disagree."""
        print('property: %s %s %s %s' % (conflicting_property.class_name,
                                         conflicting_property.name,
                                         conflicting_property.swift_type_safe(),
                                         conflicting_property.is_optional))
        print('duplicate_property: %s %s %s %s' % (duplicate_property.class_name,
                                                   duplicate_property.name,
                                                   duplicate_property.swift_type_safe(),
                                                   duplicate_property.is_optional))

    def database_subclass_properties(self):
        """Return the properties declared only by descendants of this class, sorted by name.

        These are the extra columns needed in this class's table to persist
        every (non-ignored) subclass.

        More than one subclass may declare a property with the same name. This
        is fine as long as the Swift types match. If the optionality differs
        between subclasses, the optional declaration is kept so the column can
        hold both.

        Calls ``fail`` in two cases: a Swift type mismatch, or an optionality
        mismatch with a property declared on this class itself.
        """
        all_property_map = {}
        subclass_property_map = {}
        root_property_names = set()
        for root_property in self.properties():
            all_property_map[root_property.name] = root_property
            root_property_names.add(root_property.name)

        for subclass in all_descendents_of_class(self):
            if should_ignore_class(subclass):
                continue
            for subclass_property in subclass.properties():
                duplicate_property = all_property_map.get(subclass_property.name)
                if duplicate_property is not None:
                    if subclass_property.swift_type_safe() != duplicate_property.swift_type_safe():
                        self._print_property_conflict(subclass_property, duplicate_property)
                        fail("Duplicate property doesn't match:", subclass_property.name)
                    elif subclass_property.is_optional != duplicate_property.is_optional:
                        if subclass_property.name in root_property_names:
                            self._print_property_conflict(subclass_property, duplicate_property)
                            fail("Duplicate property doesn't match:", subclass_property.name)
                        # If one subclass property is optional and the other isn't, we
                        # treat both as optional for the purposes of the database schema:
                        # only an optional declaration replaces the recorded one.
                        if not subclass_property.is_optional:
                            continue
                    else:
                        # Same type and same optionality: nothing new to record.
                        continue
                all_property_map[subclass_property.name] = subclass_property
                subclass_property_map[subclass_property.name] = subclass_property

        return [subclass_property_map[name] for name in sorted(subclass_property_map)]

    def is_sds_model(self):
        """Return True if this class transitively subclasses one of the base model classes.

        Every superclass on the chain must be present in ``global_class_map``.
        """
        if self.super_class_name is None:
            return False
        if self.super_class_name not in global_class_map:
            return False
        if self.super_class_name in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME):
            return True
        return global_class_map[self.super_class_name].is_sds_model()

    def has_sds_superclass(self):
        """Return True if the superclass is a parsed class other than a base model class.

        Returns False for root models (direct subclasses of a base model
        class), for classes with no superclass, and for classes whose
        superclass was not parsed. The result is always a bool.
        """
        return bool(self.super_class_name
                    and self.super_class_name in global_class_map
                    and self.super_class_name not in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME))

    def table_superclass(self):
        """Return the root class whose table stores instances of this class.

        Walks up the superclass chain. It stops at the first class whose
        superclass is missing, unparsed, or one of the base model classes.
        """
        super_class_name = self.super_class_name
        if super_class_name is None or super_class_name not in global_class_map:
            return self
        if super_class_name in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME):
            return self
        return global_class_map[super_class_name].table_superclass()

    def should_generate_extensions(self):
        """Return True if serialization code should be generated for this class.

        Skips the following, and logs why:
        - the base model classes themselves;
        - ignored classes;
        - non-SDS models;
        - the database migration classes and their direct subclasses, which
          should not be persisted in the data store.
        """
        migration_class_names = ('OWSDatabaseMigration', 'YDBDatabaseMigration', 'OWSResaveCollectionDBMigration')
        if self.name in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME):
            print('Ignoring class (1): %s' % (self.name,))
            return False
        if should_ignore_class(self):
            print('Ignoring class (2): %s' % (self.name,))
            return False
        if not self.is_sds_model():
            # Only write serialization extensions for SDS models.
            print('Ignoring class (3): %s' % (self.name,))
            return False
        # The migration should not be persisted in the data store.
        if self.name in migration_class_names:
            print('Ignoring class (4): %s' % (self.name,))
            return False
        if self.super_class_name in migration_class_names:
            print('Ignoring class (5): %s' % (self.name,))
            return False
        return True
class TypeInfo:
    """Describes how one property type maps to Swift, Obj-C and a database column."""

    # Swift types treated as numeric.
    # TODO: We need to revisit how we serialize numeric types.
    _NUMERIC_SWIFT_TYPES = (
        'Bool',
        'UInt64',
        'UInt',
        'Int64',
        'Int',
        'Int32',
        'UInt32',
        'Double',
        'Float',
    )

    # NSNumber accessor used to extract each Swift type when archiving an NSNumber.
    _NSNUMBER_CONVERSION_METHODS = {
        'Int8': 'int8Value',
        'UInt8': 'uint8Value',
        'Int16': 'int16Value',
        'UInt16': 'uint16Value',
        'Int32': 'int32Value',
        'UInt32': 'uint32Value',
        'Int64': 'int64Value',
        'UInt64': 'uint64Value',
        'Float': 'floatValue',
        'Double': 'doubleValue',
        'Bool': 'boolValue',
        'Int': 'intValue',
        'UInt': 'uintValue',
    }

    def __init__(self, swift_type, objc_type, should_use_blob=False, is_codable=False, is_enum=False):
        self._swift_type = swift_type
        self._objc_type = objc_type
        self.should_use_blob = should_use_blob
        self.is_codable = is_codable
        self.is_enum = is_enum

    def swift_type(self):
        return self._swift_type

    def objc_type(self):
        return self._objc_type

    def database_column_type(self, value_name):
        """Return the column type expression used in SDSColumnMetadata (e.g. '.int64').

        This defines the mapping of Swift types to database column types.
        Sub-models and collections (e.g. [String]) are currently all stored as
        blobs. Calls ``fail`` for unsupported Swift types.
        """
        # Special case this oddball type.
        if value_name == 'conversationColorName':
            return '.unicodeString'
        if self.should_use_blob or self.is_codable:
            return '.blob'
        if self.is_enum:
            return '.int'
        swift_type = self._swift_type
        if swift_type == 'String':
            return '.unicodeString'
        if swift_type == 'Date':
            return '.int64'
        if swift_type == 'Data':
            return '.blob'
        if swift_type == 'Bool':
            return '.int'
        if swift_type in ('Double', 'Float'):
            return '.double'
        if self.is_numeric():
            return '.int64'
        fail('Unknown type(1):', swift_type)

    def is_numeric(self):
        """Return True if the Swift type is one of the numeric types (Bool included)."""
        return self._swift_type in self._NUMERIC_SWIFT_TYPES

    def should_cast_to_swift(self):
        """Return True for numeric types other than Bool, Int64 and UInt64."""
        if self._swift_type in ('Bool', 'Int64', 'UInt64'):
            return False
        return self.is_numeric()

    def deserialize_record_invocation(self, property, value_name, is_optional, did_force_optional):
        """Return the Swift statements that bind ``value_name`` from a record.

        Args:
            property: The ParsedProperty being deserialized.
            value_name: Name of the Swift local to bind.
            is_optional: Whether the model property itself is optional.
            did_force_optional: True when the record column is optional but the
                model property is not, so the value must be unwrapped.
        Returns:
            A list of Swift source lines.
        """
        custom_column_name = custom_column_name_for_property(property)
        if custom_column_name is not None:
            value_expr = 'record.%s' % (custom_column_name,)
        else:
            value_expr = 'record.%s' % (value_name,)

        # Select the SDSDeserialization helpers applicable to this type.
        deserialization_optional = None
        deserialization_not_optional = None
        deserialization_conversion = ''
        if self._swift_type in ('String', 'Date') or self.is_codable:
            deserialization_not_optional = 'required'
        elif self._swift_type == 'Data':
            deserialization_optional = 'optionalData'
            deserialization_not_optional = 'required'
        elif self.is_numeric():
            deserialization_optional = 'optionalNumericAsNSNumber'
            deserialization_not_optional = 'required'
            deserialization_conversion = ', conversion: { NSNumber(value: $0) }'

        if is_optional:
            if deserialization_optional is not None:
                value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % (deserialization_optional, value_expr, value_name, deserialization_conversion)
        elif did_force_optional:
            if deserialization_not_optional is not None:
                value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % (deserialization_not_optional, value_expr, value_name)
        # Otherwise the column is non-optional and needs no unpacking.

        initializer_param_type = self.swift_type()
        if is_optional:
            initializer_param_type = initializer_param_type + '?'

        if property.has_custom_column_source():
            # The value comes from another column; any unpacking above is discarded.
            value_expr = 'record.%s' % (property.column_source(),)
            # Special-case the unpacking of the auto-incremented primary key.
            if value_expr == 'record.id':
                value_expr = 'recordId'
            value_statement = 'let %s: %s = %s(%s)' % (value_name, initializer_param_type, initializer_param_type, value_expr)
        elif value_name == 'conversationColorName':
            # Special case this oddball type.
            value_statement = 'let %s: %s = ConversationColorName(rawValue: %s)' % (value_name, "ConversationColorName", value_expr)
        elif self.is_codable:
            value_statement = 'let %s: %s = %s' % (value_name, initializer_param_type, value_expr)
        elif self.should_use_blob:
            blob_name = '%sSerialized' % (str(value_name),)
            if is_optional or did_force_optional:
                serialized_statement = 'let %s: Data? = %s' % (blob_name, value_expr)
            else:
                serialized_statement = 'let %s: Data = %s' % (blob_name, value_expr)
            if is_optional:
                value_statement = 'let %s: %s? = try SDSDeserialization.optionalUnarchive(%s, name: "%s")' % (value_name, self._swift_type, blob_name, value_name)
            else:
                value_statement = 'let %s: %s = try SDSDeserialization.unarchive(%s, name: "%s")' % (value_name, self._swift_type, blob_name, value_name)
            return [serialized_statement, value_statement]
        elif self.is_enum and did_force_optional and not is_optional:
            return [
                'guard let %s: %s = %s else {' % (value_name, initializer_param_type, value_expr),
                ' throw SDSError.missingRequiredField',
                '}',
            ]
        elif is_optional and self._objc_type == 'NSNumber *':
            return ['let %s: %s = %s' % (value_name, 'NSNumber?', value_expr)]
        else:
            value_statement = 'let %s: %s = %s' % (value_name, initializer_param_type, value_expr)
        return [value_statement]

    def serialize_record_invocation(self, property, value_name, is_optional, did_force_optional):
        """Return the Swift expression that converts model value ``value_name`` for a record.

        ``property`` is accepted for symmetry with deserialize_record_invocation
        but is not used. Calls ``fail`` for an NSNumber property whose Swift
        type has no known NSNumber accessor.
        """
        value_expr = value_name
        if value_name == 'model.conversationColorName':
            return '%s.rawValue' % (value_expr,)
        if self.is_codable:
            return value_expr
        if self.should_use_blob:
            if is_optional or did_force_optional:
                return 'optionalArchive(%s)' % (value_expr,)
            return 'requiredArchive(%s)' % (value_expr,)
        if self._objc_type == 'NSNumber *':
            # Use .get() so unknown types reach fail() with a useful message
            # instead of raising KeyError from a plain subscript.
            conversion_method = self._NSNUMBER_CONVERSION_METHODS.get(self.swift_type())
            if conversion_method is None:
                fail('Could not convert:', self.swift_type())
            serialization_conversion = '{ $0.%s }' % (conversion_method,)
            if is_optional or did_force_optional:
                return 'archiveOptionalNSNumber(%s, conversion: %s)' % (value_expr, serialization_conversion)
            return 'archiveNSNumber(%s, conversion: %s)' % (value_expr, serialization_conversion)
        return value_expr

    def record_field_type(self, value_name):
        """Return the Swift type of the record field that stores this value."""
        # Special case this oddball type.
        if value_name == 'conversationColorName':
            return 'String'
        # Non-codable blobs are archived into Data; codable values keep their type.
        if self.should_use_blob and not self.is_codable:
            return 'Data'
        return self.swift_type()
class ParsedProperty:
    """A single Obj-C property of a model class, as described by the parser JSON."""

    # Obj-C types with a direct Swift equivalent. 'NSNumber *' is handled
    # separately because its Swift type depends on the property.
    _OBJC_PRIMITIVE_TO_SWIFT = {
        'NSString *': 'String',
        'NSDate *': 'Date',
        'NSData *': 'Data',
        'BOOL': 'Bool',
        'NSInteger': 'Int',
        'NSUInteger': 'UInt',
        'int32_t': 'Int32',
        'int64_t': 'Int64',
        'long long': 'Int64',
        'unsigned long long': 'UInt64',
        'uint64_t': 'UInt64',
        'unsigned long': 'UInt64',
        'unsigned int': 'UInt32',
        'double': 'Double',
        'float': 'Float',
        'CGFloat': 'Double',
    }

    def __init__(self, json_dict):
        self.name = json_dict.get('name')
        self.is_optional = json_dict.get('is_optional')
        self.objc_type = json_dict.get('objc_type')
        self.class_name = json_dict.get('class_name')
        # When set to a non-None value, type_info() uses this Swift type
        # instead of deriving one from objc_type.
        self.swift_type = None

    @staticmethod
    def _split_top_level_comma(type_arguments):
        """Split generic arguments 'K, V' at the first comma not nested in '<...>'.

        Returns a (key, value) tuple of stripped strings, or None if there is
        no top-level comma.
        """
        depth = 0
        for index, char in enumerate(type_arguments):
            if char == '<':
                depth += 1
            elif char == '>':
                depth -= 1
            elif char == ',' and depth == 0:
                return type_arguments[:index].strip(), type_arguments[index + 1:].strip()
        return None

    def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True):
        """Return the Swift type for a primitive/value Obj-C type, or None.

        'NSNumber *' maps to the property-specific Swift type when
        ``unpack_nsnumber`` is true, and to 'NSNumber' otherwise.
        Calls ``fail`` if ``objc_type`` is None.
        """
        if objc_type is None:
            fail('Missing type')
            return None
        swift_type = self._OBJC_PRIMITIVE_TO_SWIFT.get(objc_type)
        if swift_type is not None:
            return swift_type
        if objc_type == 'NSNumber *':
            if unpack_nsnumber:
                return swift_type_for_nsnumber(self)
            return 'NSNumber'
        return None

    def _convert_objc_type_argument(self, objc_argument, container_objc_type):
        """Convert one generic argument of a collection type (NSNumber is never unpacked)."""
        swift_type = self.convert_objc_class_to_swift(objc_argument, unpack_nsnumber=False)
        if swift_type is None:
            # Previously this surfaced as an opaque TypeError ('[' + None).
            fail('Unexpected collection element type:', objc_argument, container_objc_type)
        return swift_type

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True):
        """Convert an Obj-C object pointer type (ending in ' *') to a Swift type name.

        Returns None if ``objc_type`` does not end in ' *'. Array element types
        and dictionary key/value types are converted recursively. Dictionary
        arguments are split at the top-level comma, so nested dictionary
        values work. Calls ``fail`` for types that cannot be converted.
        """
        if not objc_type.endswith(' *'):
            return None
        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type, unpack_nsnumber=unpack_nsnumber)
        if swift_primitive is not None:
            return swift_primitive

        array_match = re.search(r'^NS(Mutable)?Array<(.+)> \*$', objc_type)
        if array_match is not None:
            element_type = self._convert_objc_type_argument(array_match.group(2), objc_type)
            return '[' + element_type + ']'

        dict_match = re.search(r'^NS(Mutable)?Dictionary<(.+)> \*$', objc_type)
        if dict_match is not None:
            dict_arguments = self._split_top_level_comma(dict_match.group(2))
            if dict_arguments is not None:
                key_type = self._convert_objc_type_argument(dict_arguments[0], objc_type)
                value_type = self._convert_objc_type_argument(dict_arguments[1], objc_type)
                return '[' + key_type + ': ' + value_type + ']'

        swift_type = objc_type[:-len(' *')]
        if '<' in swift_type or '{' in swift_type or '*' in swift_type:
            fail('Unexpected type:', objc_type)
        return swift_type

    def try_to_convert_objc_type_to_type_info(self):
        """Derive a TypeInfo from this property's Obj-C type.

        Checks, in order: enums, primitives, CG structs (blob + codable), and
        object types (blob, codable only if is_objc_type_codable). Calls
        ``fail`` if the type is missing or unknown.
        """
        objc_type = self.objc_type
        if objc_type is None:
            fail('Missing type')
        elif is_flagged_as_enum_property(self):
            return TypeInfo(objc_type, objc_type, is_enum=True)
        elif objc_type in enum_type_map:
            return TypeInfo(objc_type, objc_type, is_enum=True)
        elif objc_type.startswith('enum '):
            return TypeInfo(objc_type[len('enum '):], objc_type, is_enum=True)

        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type)
        if swift_primitive is not None:
            return TypeInfo(swift_primitive, objc_type)

        if objc_type in ('struct CGSize', 'struct CGRect', 'struct CGPoint'):
            struct_type = objc_type[len('struct '):]
            return TypeInfo(struct_type, struct_type, should_use_blob=True, is_codable=True)

        swift_type = self.convert_objc_class_to_swift(self.objc_type)
        if swift_type is not None:
            is_codable = self.is_objc_type_codable(objc_type)
            return TypeInfo(swift_type, objc_type, should_use_blob=True, is_codable=is_codable)
        fail('Unknown type(3):', self.class_name, self.objc_type, self.name)

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def is_objc_type_codable(self, objc_type):
        """Return True if values of ``objc_type`` can be persisted via Codable.

        While ONLY_USE_CODABLE_FOR_PRIMITIVES is set, only strings, CG structs
        and enums qualify.
        """
        if objc_type in ('NSString *',):
            return True
        elif objc_type in ('struct CGSize', 'struct CGRect', 'struct CGPoint'):
            return True
        elif is_flagged_as_enum_property(self):
            return True
        elif objc_type in enum_type_map:
            return True
        elif objc_type.startswith('enum '):
            return True
        if ONLY_USE_CODABLE_FOR_PRIMITIVES:
            return False
        array_match = re.search(r'^NS(Mutable)?Array<(.+)> \*$', objc_type)
        if array_match is not None:
            return self.is_objc_type_codable(array_match.group(2))
        dict_match = re.search(r'^NS(Mutable)?Dictionary<(.+)> \*$', objc_type)
        if dict_match is not None:
            dict_arguments = self._split_top_level_comma(dict_match.group(2))
            if dict_arguments is not None:
                return self.is_objc_type_codable(dict_arguments[0]) and self.is_objc_type_codable(dict_arguments[1])
        return False

    def type_info(self):
        """Return the TypeInfo for this property (explicit swift_type wins over objc_type)."""
        if self.swift_type is not None:
            should_use_blob = (self.swift_type.startswith(('[', '{')) or is_swift_class_name(self.swift_type))
            # Bug fix: this used the undefined name `objc_type` (NameError).
            return TypeInfo(self.swift_type, self.objc_type, should_use_blob=should_use_blob, is_codable=should_use_blob)
        return self.try_to_convert_objc_type_to_type_info()

    def swift_type_safe(self):
        return self.type_info().swift_type()

    def objc_type_safe(self):
        """Return the Obj-C type to use in generated code, without any 'enum ' prefix."""
        # Special case this oddball type.
        #
        # TODO: We might want to handle this within TypeInfo.
        if self.name == 'conversationColorName':
            return 'ConversationColorName'
        result = self.type_info().objc_type()
        if result.startswith('enum '):
            result = result[len('enum '):]
        return result

    def database_column_type(self):
        return self.type_info().database_column_type(self.name)

    def should_ignore_property(self):
        # Delegates to the module-level should_ignore_property() function.
        return should_ignore_property(self)

    def column_source(self):
        """Return the custom column source for this property, or its own name."""
        custom_name = custom_property_column_source(self)
        if custom_name is not None:
            return custom_name
        return self.name

    def has_custom_column_source(self):
        return custom_property_column_source(self) is not None

    def deserialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().deserialize_record_invocation(self, value_name, self.is_optional, did_force_optional)

    def serialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().serialize_record_invocation(self, value_name, self.is_optional, did_force_optional)

    def record_field_type(self):
        return self.type_info().record_field_type(self.name)

    def is_enum(self):
        return self.type_info().is_enum
def ows_getoutput(cmd):
    """Run ``cmd`` (an argv list, no shell) and capture its output.

    Returns a ``(returncode, stdout, stderr)`` tuple once the process exits.
    """
    pipe = subprocess.PIPE
    process = subprocess.Popen(cmd, stdout=pipe, stderr=pipe)
    captured_out, captured_err = process.communicate()
    return process.returncode, captured_out, captured_err
# ---- Parsing
def properties_and_inherited_properties(clazz):
    """Return the properties of ``clazz`` and of all its parsed superclasses.

    The most distant ancestor's properties come first. Each class contributes
    its own properties in name order. Superclasses not present in
    ``global_class_map`` end the walk.
    """
    inherited = []
    if clazz.super_class_name in global_class_map:
        parent = global_class_map[clazz.super_class_name]
        inherited = properties_and_inherited_properties(parent)
    return inherited + clazz.properties()
def generate_swift_extensions_for_model(clazz):
print '\t', 'processing', clazz.__dict__
if not clazz.should_generate_extensions():
return
has_sds_superclass = clazz.has_sds_superclass()
print '\t', '\t', 'clazz.name', clazz.name, type(clazz.name)
print '\t', '\t', 'clazz.super_class_name', clazz.super_class_name
print '\t', '\t', 'filepath', clazz.filepath
print '\t', '\t', 'table_superclass', clazz.table_superclass().name
print '\t', '\t', 'has_sds_superclass', has_sds_superclass
swift_filename = os.path.basename(clazz.filepath)
swift_filename = swift_filename[:swift_filename.find('.')] + '+SDS.swift'
swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename)
print '\t', '\t', 'swift_filepath', swift_filepath
record_type = get_record_type(clazz)
print '\t', '\t', 'record_type', record_type
# TODO: We'll need to import SignalServiceKit for non-SSK models.
swift_body = '''//
// Copyright (c) 2019 Open Whisper Systems. All rights reserved.
//
import Foundation
import GRDBCipher
import SignalCoreKit
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
''' % ( sds_common.pretty_module_path(__file__), )
if not has_sds_superclass:
# If a property has a custom column source, we don't redundantly create a column for that column
base_properties = [property for property in clazz.properties() if not property.has_custom_column_source()]
# If a property has a custom column source, we don't redundantly create a column for that column
subclass_properties = [property for property in clazz.database_subclass_properties() if not property.has_custom_column_source()]
swift_body += '''
// MARK: - Record
'''
record_name = remove_prefix_from_class_name(clazz.name) + 'Record'
swift_body += '''
public struct %s: SDSRecord {
public var tableMetadata: SDSTableMetadata {
return %sSerializer.table
}
public static let databaseTableName: String = %sSerializer.table.tableName
public var id: Int64?
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
public let recordType: SDSRecordType
public let uniqueId: String
''' % ( record_name, str(clazz.name), str(clazz.name), )
def write_record_property(property, force_optional=False):
column_name = to_swift_identifier_name(property.name)
# print 'property', property.swift_type_safe()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = '?' if is_optional else ''
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
column_name = custom_column_name
return ''' public let %s: %s%s
''' % ( str(column_name), record_field_type, optional_split, )
if len(base_properties) > 0:
swift_body += '\n // Base class properties \n'
for property in base_properties:
# print 'base_properties:', property.name
swift_body += write_record_property(property)
if len(subclass_properties) > 0:
swift_body += '\n // Subclass properties \n'
for property in subclass_properties:
# print 'subclass_properties:', property.name
swift_body += write_record_property(property, force_optional=True)
swift_body += '''
public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable {
case id
case recordType
case uniqueId
'''
for property in (base_properties + subclass_properties):
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
swift_body += ''' case %s = "%s"
''' % ( custom_column_name, to_swift_identifier_name(property.name), )
else:
swift_body += ''' case %s
''' % ( to_swift_identifier_name(property.name), )
swift_body += ''' }
'''
swift_body += '''
public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String {
return fullyQualified ? "\(databaseTableName).\(column.rawValue)" : column.rawValue
}
''' % ( record_name, )
swift_body += '''}
// MARK: - StringInterpolation
public extension String.StringInterpolation {
mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column))
}
mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column, fullyQualified: true))
}
}
''' % { 'record_identifier': record_identifier(clazz.name), 'record_name': record_name }
swift_body += '''
// MARK: - Deserialization
// TODO: Rework metadata to not include, for example, columns, column indices.
extension %s {
// This method defines how to deserialize a model, given a
// database row. The recordType column is used to determine
// the corresponding model class.
class func fromRecord(_ record: %s) throws -> %s {
''' % ( str(clazz.name), record_name, str(clazz.name), )
swift_body += '''
guard let recordId = record.id else {
throw SDSError.invalidValue
}
switch record.recordType {
'''
deserialize_classes = all_descendents_of_class(clazz) + [clazz]
deserialize_classes.sort(key=lambda value: value.name)
for deserialize_class in deserialize_classes:
if should_ignore_class(deserialize_class):
continue
initializer_params = []
objc_initializer_params = []
objc_super_initializer_args = []
objc_initializer_assigns = []
deserialize_record_type = get_record_type_enum_name(deserialize_class.name)
swift_body += ''' case .%s:
''' % ( str(deserialize_record_type), )
swift_body += '''
let uniqueId: String = record.uniqueId
'''
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(deserialize_class)
has_local_properties = False
for property in deserialize_properties:
value_name = '%s' % property.name
if property.name not in ( 'uniqueId', ):
did_force_optional = (property.name not in base_property_names) and (not property.is_optional)
for statement in property.deserialize_record_invocation(value_name, did_force_optional):
# print 'statement', statement, type(statement)
swift_body += ' %s\n' % ( str(statement), )
initializer_params.append('%s: %s' % ( str(property.name), value_name, ) )
objc_initializer_type = str(property.objc_type_safe())
if objc_initializer_type.startswith('NSMutable'):
objc_initializer_type = 'NS' + objc_initializer_type[len('NSMutable'):]
if property.is_optional:
objc_initializer_type = 'nullable ' + objc_initializer_type
objc_initializer_params.append('%s:(%s)%s' % ( str(property.name), objc_initializer_type, str(property.name), ) )
is_superclass_property = property.class_name != deserialize_class.name
if is_superclass_property:
objc_super_initializer_args.append('%s:%s' % ( str(property.name), str(property.name), ) )
else:
has_local_properties = True
if str(property.objc_type_safe()).startswith('NSMutableArray'):
objc_initializer_assigns.append('_%s = %s ? [%s mutableCopy] : [NSMutableArray new];' % ( str(property.name), str(property.name), str(property.name), ) )
elif str(property.objc_type_safe()).startswith('NSMutableDictionary'):
objc_initializer_assigns.append('_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];' % ( str(property.name), str(property.name), str(property.name), ) )
else:
objc_initializer_assigns.append('_%s = %s;' % ( str(property.name), str(property.name), ) )
# --- Initializer Snippets
h_snippet = ''
h_snippet += '''
// clang-format off
- (instancetype)initWithUniqueId:(NSString *)uniqueId
'''
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(0, len('- (instancetype)initWithUniqueId') - objc_initializer_param.index(':'))
h_snippet += (' ' * alignment) + objc_initializer_param + '\n'
h_snippet += 'NS_SWIFT_NAME(init(%s:));\n' % ':'.join([str(property.name) for property in deserialize_properties])
h_snippet += '''
// clang-format on
'''
m_snippet = ''
m_snippet += '''
// clang-format off
- (instancetype)initWithUniqueId:(NSString *)uniqueId
'''
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(0, len('- (instancetype)initWithUniqueId') - objc_initializer_param.index(':'))
m_snippet += (' ' * alignment) + objc_initializer_param + '\n'
if len(objc_super_initializer_args) == 1:
suffix = '];'
else:
suffix = ''
m_snippet += '''{
self = [super initWithUniqueId:uniqueId%s
''' % (suffix)
for index, objc_super_initializer_arg in enumerate(objc_super_initializer_args[1:]):
alignment = max(0, len(' self = [super initWithUniqueId') - objc_super_initializer_arg.index(':'))
if index == len(objc_super_initializer_args) - 2:
suffix = '];'
else:
suffix = ''
m_snippet += (' ' * alignment) + objc_super_initializer_arg + suffix + '\n'
m_snippet += '''
if (!self) {
return self;
}
'''
for objc_initializer_assign in objc_initializer_assigns:
m_snippet += (' ' * 4) + objc_initializer_assign + '\n'
if deserialize_class.finalize_method_name is not None:
m_snippet += '''
[self %s];
''' % ( str(deserialize_class.finalize_method_name), )
m_snippet += '''
return self;
}
// clang-format on
'''
# Skip initializer generation for classes without any properties.
if not has_local_properties:
h_snippet = ''
m_snippet = ''
if deserialize_class.filepath.endswith('.m'):
m_filepath = deserialize_class.filepath
h_filepath = m_filepath[:-2] + '.h'
update_objc_snippet(h_filepath, h_snippet)
update_objc_snippet(m_filepath, m_snippet)
swift_body += '''
'''
# --- Invoke Initializer
initializer_invocation = ' return %s(' % str(deserialize_class.name)
swift_body += initializer_invocation
swift_body += (',\n' + ' ' * len(initializer_invocation)).join(initializer_params)
swift_body += ')'
swift_body += '''
'''
# TODO: We could generate a comment with the Obj-C (or Swift) model initializer
# that this deserialization code expects.
swift_body += ''' default:
owsFailDebug("Unexpected record type: \(record.recordType)")
throw SDSError.invalidValue
'''
swift_body += ''' }
'''
swift_body += ''' }
'''
swift_body += '''}
'''
# TODO: Remove the serialization glue below.
if not has_sds_superclass:
swift_body += '''
// MARK: - SDSModel
extension %s: SDSModel {
public var serializer: SDSSerializer {
// Any subclass can be cast to it's superclass,
// so the order of this switch statement matters.
// We need to do a "depth first" search by type.
switch self {''' % str(clazz.name)
for subclass in reversed(all_descendents_of_class(clazz)):
if should_ignore_class(subclass):
continue
swift_body += '''
case let model as %s:
assert(type(of: model) == %s.self)
return %sSerializer(model: model)''' % ( str(subclass.name), str(subclass.name), str(subclass.name), )
swift_body += '''
default:
return %sSerializer(model: self)
}
}
public func asRecord() throws -> SDSRecord {
return try serializer.asRecord()
}
}
''' % ( str(clazz.name), )
if not has_sds_superclass:
swift_body += '''
// MARK: - Table Metadata
extension %sSerializer {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
static let recordTypeColumn = SDSColumnMetadata(columnName: "recordType", columnType: .int, columnIndex: 0)
static let idColumn = SDSColumnMetadata(columnName: "id", columnType: .primaryKey, columnIndex: 1)
static let uniqueIdColumn = SDSColumnMetadata(columnName: "uniqueId", columnType: .unicodeString, columnIndex: 2)
''' % str(clazz.name)
# Eventually we need a (persistent?) mechanism for guaranteeing
# consistency of column ordering, that is robust to schema
# changes, class hierarchy changes, etc.
column_property_names = []
column_property_names.append('recordType')
column_property_names.append('id')
column_property_names.append('uniqueId')
def write_column_metadata(property, force_optional=False):
column_index = len(column_property_names)
column_name = to_swift_identifier_name(property.name)
column_property_names.append(column_name)
is_optional = property.is_optional or force_optional
optional_split = ', isOptional: true' if is_optional else ''
# print 'property', property.swift_type_safe()
database_column_type = property.database_column_type()
# TODO: Use skipSelect.
return ''' static let %sColumn = SDSColumnMetadata(columnName: "%s", columnType: %s%s, columnIndex: %s)
''' % ( str(column_name), str(column_name), database_column_type, optional_split, str(column_index) )
# If a property has a custom column source, we don't redundantly create a column for that column
base_properties = [property for property in clazz.properties() if not property.has_custom_column_source()]
if len(base_properties) > 0:
swift_body += ' // Base class properties \n'
for property in base_properties:
swift_body += write_column_metadata(property)
# If a property has a custom column source, we don't redundantly create a column for that column
subclass_properties = [property for property in clazz.database_subclass_properties() if not property.has_custom_column_source()]
if len(subclass_properties) > 0:
swift_body += ' // Subclass properties \n'
for property in subclass_properties:
swift_body += write_column_metadata(property, force_optional=True)
database_table_name = 'model_%s' % str(clazz.name)
swift_body += '''
// TODO: We should decide on a naming convention for
// tables that store models.
public static let table = SDSTableMetadata(tableName: "%s", columns: [
''' % database_table_name
for column_property_name in column_property_names:
swift_body += ''' %sColumn,
''' % ( str(column_property_name) )
swift_body += ''' ])
}
'''
# ---- Fetch ----
swift_body += '''
// MARK: - Save/Remove/Update
@objc
public extension %s {
func anyInsert(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .insert, transaction: transaction)
}
// This method is private; we should never use it directly.
// Instead, use anyUpdate(transaction:block:), so that we
// use the "update with" pattern.
private func anyUpdate(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .update, transaction: transaction)
}
@available(*, deprecated, message: "Use anyInsert() or anyUpdate() instead.")
func anyUpsert(transaction: SDSAnyWriteTransaction) {
let isInserting: Bool
if let uniqueId = uniqueId {
if %s.anyFetch(uniqueId: uniqueId, transaction: transaction) != nil {
isInserting = false
} else {
isInserting = true
}
} else {
owsFailDebug("Missing uniqueId: \(type(of:self))")
isInserting = true
}
sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction)
}
// This method is used by "updateWith..." methods.
//
// This model may be updated from many threads. We don't want to save
// our local copy (this instance) since it may be out of date. We also
// want to avoid re-saving a model that has been deleted. Therefore, we
// use "updateWith..." methods to:
//
// a) Update a property of this instance.
// b) If a copy of this model exists in the database, load an up-to-date copy,
// and update and save that copy.
// b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save
// this local instance.
//
// After "updateWith...":
//
// a) Any copy of this model in the database will have been updated.
// b) The local property on this instance will always have been updated.
// c) Other properties on this instance may be out of date.
//
// All mutable properties of this class have been made read-only to
// prevent accidentally modifying them directly.
//
// This isn't a perfect arrangement, but in practice this will prevent
// data loss and will resolve all known issues.
func anyUpdate(transaction: SDSAnyWriteTransaction, block: (%s) -> Void) {
guard let uniqueId = uniqueId else {
owsFailDebug("Missing uniqueId.")
return
}
block(self)
guard let dbCopy = type(of: self).anyFetch(uniqueId: uniqueId,
transaction: transaction) else {
return
}
// Don't apply the block twice to the same instance.
// It's at least unnecessary and actually wrong for some blocks.
// e.g. `block: { $0 in $0.someField++ }`
if dbCopy !== self {
block(dbCopy)
}
dbCopy.anyUpdate(transaction: transaction)
}
func anyRemove(transaction: SDSAnyWriteTransaction) {
anyWillRemove(with: transaction)
switch transaction.writeTransaction {
case .yapWrite(let ydbTransaction):
ydb_remove(with: ydbTransaction)
case .grdbWrite(let grdbTransaction):
do {
let record = try asRecord()
record.sdsRemove(transaction: grdbTransaction)
} catch {
owsFail("Remove failed: \(error)")
}
}
anyDidRemove(with: transaction)
}
func anyReload(transaction: SDSAnyReadTransaction) {
anyReload(transaction: transaction, ignoreMissing: false)
}
func anyReload(transaction: SDSAnyReadTransaction, ignoreMissing: Bool) {
guard let uniqueId = self.uniqueId else {
owsFailDebug("uniqueId was unexpectedly nil")
return
}
guard let latestVersion = type(of: self).anyFetch(uniqueId: uniqueId, transaction: transaction) else {
if !ignoreMissing {
owsFailDebug("`latest` was unexpectedly nil")
}
return
}
setValuesForKeys(latestVersion.dictionaryValue)
}
}
''' % ( ( str(clazz.name), ) * 3 )
# ---- Cursor ----
swift_body += '''
// MARK: - %sCursor
@objc
public class %sCursor: NSObject {
private let cursor: RecordCursor<%s>?
init(cursor: RecordCursor<%s>?) {
self.cursor = cursor
}
public func next() throws -> %s? {
guard let cursor = cursor else {
return nil
}
guard let record = try cursor.next() else {
return nil
}
return try %s.fromRecord(record)
}
public func all() throws -> [%s] {
var result = [%s]()
while true {
guard let model = try next() else {
break
}
result.append(model)
}
return result
}
}
''' % ( str(clazz.name), str(clazz.name), record_name, record_name, str(clazz.name), str(clazz.name), str(clazz.name), str(clazz.name), )
# ---- Fetch ----
swift_body += '''
// MARK: - Obj-C Fetch
// TODO: We may eventually want to define some combination of:
//
// * fetchCursor, fetchOne, fetchAll, etc. (ala GRDB)
// * Optional "where clause" parameters for filtering.
// * Async flavors with completions.
//
// TODO: I've defined flavors that take a read transaction.
// Or we might take a "connection" if we end up having that class.
@objc
public extension %s {
class func grdbFetchCursor(transaction: GRDBReadTransaction) -> %sCursor {
let database = transaction.database
do {
let cursor = try %s.fetchCursor(database)
return %sCursor(cursor: cursor)
} catch {
owsFailDebug("Read failed: \(error)")
return %sCursor(cursor: nil)
}
}
''' % ( str(clazz.name), str(clazz.name), record_name, str(clazz.name), str(clazz.name), )
swift_body += '''
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String,
transaction: SDSAnyReadTransaction) -> %(class_name)s? {
assert(uniqueId.count > 0)
switch transaction.readTransaction {
case .yapRead(let ydbTransaction):
return %(class_name)s.ydb_fetch(uniqueId: uniqueId, transaction: ydbTransaction)
case .grdbRead(let grdbTransaction):
let sql = "SELECT * FROM \(%(record_name)s.databaseTableName) WHERE \(%(record_identifier)sColumn: .uniqueId) = ?"
return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: grdbTransaction)
}
}
''' % { "class_name": str(clazz.name), "record_name": record_name, "record_identifier": record_identifier(clazz.name) }
swift_body += '''
// Traverses all records.
// Records are not visited in any particular order.
// Traversal aborts if the visitor returns false.
class func anyEnumerate(transaction: SDSAnyReadTransaction, block: @escaping (%s, UnsafeMutablePointer<ObjCBool>) -> Void) {
switch transaction.readTransaction {
case .yapRead(let ydbTransaction):
%s.ydb_enumerateCollectionObjects(with: ydbTransaction) { (object, stop) in
guard let value = object as? %s else {
owsFailDebug("unexpected object: \(type(of: object))")
return
}
block(value, stop)
}
case .grdbRead(let grdbTransaction):
do {
let cursor = %s.grdbFetchCursor(transaction: grdbTransaction)
var stop: ObjCBool = false
while let value = try cursor.next() {
block(value, &stop)
guard !stop.boolValue else {
break
}
}
} catch let error as NSError {
owsFailDebug("Couldn't fetch models: \(error)")
}
}
}
// Does not order the results.
class func anyFetchAll(transaction: SDSAnyReadTransaction) -> [%s] {
var result = [%s]()
anyEnumerate(transaction: transaction) { (model, _) in
result.append(model)
}
return result
}
''' % ( ( str(clazz.name), ) * 6 )
# ---- Count ----
swift_body += '''
class func anyCount(transaction: SDSAnyReadTransaction) -> UInt {
switch transaction.readTransaction {
case .yapRead(let ydbTransaction):
return ydbTransaction.numberOfKeys(inCollection: %s.collection())
case .grdbRead(let grdbTransaction):
return %s.ows_fetchCount(grdbTransaction.database)
}
}
''' % ( str(clazz.name), record_name, )
# ---- Remove All ----
swift_body += '''
// WARNING: Do not use this method for any models which do cleanup
// in their anyWillRemove(), anyDidRemove() methods.
class func anyRemoveAllWithoutInstantation(transaction: SDSAnyWriteTransaction) {
switch transaction.writeTransaction {
case .yapWrite(let ydbTransaction):
ydbTransaction.removeAllObjects(inCollection: %s.collection())
case .grdbWrite(let grdbTransaction):
do {
try %s.deleteAll(grdbTransaction.database)
} catch {
owsFailDebug("deleteAll() failed: \(error)")
}
}
}
class func anyRemoveAllWithInstantation(transaction: SDSAnyWriteTransaction) {
anyEnumerate(transaction: transaction) { (instance, stop) in
instance.anyRemove(transaction: transaction)
}
}
}
''' % ( str(clazz.name), record_name, )
# ---- Fetch ----
swift_body += '''
// MARK: - Swift Fetch
public extension %s {
class func grdbFetchCursor(sql: String,
arguments: [DatabaseValueConvertible]?,
transaction: GRDBReadTransaction) -> %sCursor {
var statementArguments: StatementArguments?
if let arguments = arguments {
guard let statementArgs = StatementArguments(arguments) else {
owsFailDebug("Could not convert arguments.")
return %sCursor(cursor: nil)
}
statementArguments = statementArgs
}
let database = transaction.database
do {
let statement: SelectStatement = try database.cachedSelectStatement(sql: sql)
let cursor = try %s.fetchCursor(statement, arguments: statementArguments)
return %sCursor(cursor: cursor)
} catch {
Logger.error("sql: \(sql)")
owsFailDebug("Read failed: \(error)")
return %sCursor(cursor: nil)
}
}
''' % ( str(clazz.name), str(clazz.name), str(clazz.name), record_name, str(clazz.name), str(clazz.name), )
string_interpolation_name = remove_prefix_from_class_name(clazz.name)
swift_body += '''
class func grdbFetchOne(sql: String,
arguments: StatementArguments,
transaction: GRDBReadTransaction) -> %s? {
assert(sql.count > 0)
do {
guard let record = try %s.fetchOne(transaction.database, sql: sql, arguments: arguments) else {
return nil
}
return try %s.fromRecord(record)
} catch {
owsFailDebug("error: \(error)")
return nil
}
}
}
''' % ( str(clazz.name), record_name, str(clazz.name), )
# ---- Typed Convenience Methods ----
if has_sds_superclass:
swift_body += '''
// MARK: - Typed Convenience Methods
@objc
public extension %s {
// NOTE: This method will fail if the object has unexpected type.
@objc
func anyUpdate%s(transaction: SDSAnyWriteTransaction, block: (%s) -> Void) {
anyUpdate(transaction: transaction) { (object) in
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \(type(of: object))")
return
}
block(instance)
}
}
}
''' % ( str(clazz.name), str(remove_prefix_from_class_name(clazz.name)), str(clazz.name), str(clazz.name), )
# ---- SDSModel ----
table_superclass = clazz.table_superclass()
table_class_name = str(table_superclass.name)
has_serializable_superclass = table_superclass.name != clazz.name
override_keyword = ''
swift_body += '''
// MARK: - SDSSerializer
// The SDSSerializer protocol specifies how to insert and update the
// row that corresponds to this model.
class %sSerializer: SDSSerializer {
private let model: %s
public required init(model: %s) {
self.model = model
}
''' % ( str(clazz.name), str(clazz.name), str(clazz.name), )
# --- To Record
root_class = clazz.table_superclass()
root_record_name = remove_prefix_from_class_name(root_class.name) + 'Record'
serialize_record_type = get_record_type_enum_name(clazz.name)
swift_body += '''
// MARK: - Record
func asRecord() throws -> SDSRecord {
let id: Int64? = nil
let recordType: SDSRecordType = .%s
guard let uniqueId: String = model.uniqueId else {
owsFailDebug("Missing uniqueId.")
throw SDSError.missingRequiredField
}
''' % ( serialize_record_type, )
initializer_args = ['id', 'recordType', 'uniqueId', ]
# If a property has a custom column source, we don't redundantly create a column for that column
root_base_properties = [property for property in root_class.properties() if not property.has_custom_column_source()]
# If a property has a custom column source, we don't redundantly create a column for that column
root_subclass_properties = [property for property in root_class.database_subclass_properties() if not property.has_custom_column_source()]
root_base_property_names = set()
for property in root_base_properties:
root_base_property_names.add(property.name)
# record_name = remove_prefix_from_class_name(clazz.name) + 'Record'
initializer_value_names = []
for property in properties_and_inherited_properties(clazz):
initializer_value_names.append(property.name)
# print 'initializer_value_names', initializer_value_names
def write_record_property(property, force_optional=False):
column_name = to_swift_identifier_name(property.name)
optional_value = ''
if column_name in initializer_value_names:
did_force_optional = (property.name not in root_base_property_names) and (not property.is_optional)
model_accessor = accessor_name_for_property(property)
value_expr = property.serialize_record_invocation('model.%s' % ( model_accessor, ), did_force_optional)
optional_value = ' = %s' % ( value_expr, )
else:
optional_value = ' = nil'
# print 'property', property.swift_type_safe()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = '?' if is_optional else ''
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
column_name = custom_column_name
initializer_args.append(str(column_name))
return ''' let %s: %s%s%s
''' % ( str(column_name), record_field_type, optional_split, optional_value, )
if len(root_base_properties) > 0:
swift_body += '\n // Base class properties \n'
for property in root_base_properties:
# print 'base_properties:', property.name
swift_body += write_record_property(property)
if len(root_subclass_properties) > 0:
swift_body += '\n // Subclass properties \n'
for property in root_subclass_properties:
# print 'subclass_properties:', property.name
swift_body += write_record_property(property, force_optional=True)
initializer_args = ['%s: %s' % ( arg, arg, ) for arg in initializer_args]
serialize_record_type = get_record_type_enum_name(clazz.name)
swift_body += '''
return %s(%s)
}
''' % ( root_record_name, ', '.join(initializer_args), )
swift_body += '''}
'''
# print 'swift_body', swift_body
print 'Writing:', swift_filepath
swift_body = sds_common.clean_up_generated_swift(swift_body)
# Add some random whitespace to trigger the auto-formatter.
swift_body = swift_body + (' ' * random.randint(1, 100))
sds_common.write_text_file_if_changed(swift_filepath, swift_body)
def process_class_map(class_map):
print 'processing', class_map
for clazz in class_map.values():
generate_swift_extensions_for_model(clazz)
# ---- Record Type Map

# Maps model class name -> persisted integer "record type" value.
# Filled in by update_record_type_map() from the record type JSON file;
# keys that start with '#' (e.g. '#comment') are annotations, not classes.
record_type_map = {}
# It's critical that our "record type" values are consistent, even if we add/remove/rename model classes.
# Therefore we persist the mapping of known classes in a JSON file that is under source control.
def update_record_type_map(record_type_swift_path, record_type_json_path):
print 'update_record_type_map'
record_type_map_filepath = record_type_json_path
if os.path.exists(record_type_map_filepath):
with open(record_type_map_filepath, 'rt') as f:
json_string = f.read()
json_data = json.loads(json_string)
record_type_map.update(json_data)
max_record_type = 0
for class_name in record_type_map:
if class_name.startswith('#'):
continue
record_type = record_type_map[class_name]
max_record_type = max(max_record_type, record_type)
for clazz in global_class_map.values():
if clazz.name not in record_type_map:
if not clazz.should_generate_extensions():
continue
max_record_type = int(max_record_type) + 1
record_type = max_record_type
record_type_map[clazz.name] = record_type
record_type_map['#comment'] = 'NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`.' % ( sds_common.pretty_module_path(__file__), )
json_string = json.dumps(record_type_map, sort_keys=True, indent=4)
sds_common.write_text_file_if_changed(record_type_map_filepath, json_string)
# TODO: We'll need to import SignalServiceKit for non-SSK classes.
swift_body = '''//
// Copyright © 2019 Signal. All rights reserved.
//
import Foundation
import GRDBCipher
import SignalCoreKit
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
@objc
public enum SDSRecordType: UInt {
''' % ( sds_common.pretty_module_path(__file__), )
for key in sorted(record_type_map.keys()):
if key.startswith('#'):
# Ignore comments
continue
enum_name = get_record_type_enum_name(key)
# print 'enum_name', enum_name
swift_body += ''' case %s = %s
''' % ( str(enum_name), str(record_type_map[key]), )
swift_body += '''}
'''
# print 'swift_body', swift_body
swift_body = sds_common.clean_up_generated_swift(swift_body)
sds_common.write_text_file_if_changed(record_type_swift_path, swift_body)
def get_record_type(clazz):
    """Return the persisted record type value assigned to ``clazz``.

    Raises KeyError if the class has no entry in ``record_type_map``.
    """
    class_name = clazz.name
    return record_type_map[class_name]
def remove_prefix_from_class_name(class_name):
    """Strip one leading 'TS', 'OWS' or 'SSK' prefix from ``class_name``.

    Only the first matching prefix is removed, at most once; names without a
    known prefix are returned unchanged.
    """
    for prefix in ('TS', 'OWS', 'SSK'):
        if class_name.startswith(prefix):
            return class_name[len(prefix):]
    return class_name
def get_record_type_enum_name(class_name):
    """Return the Swift ``SDSRecordType`` case name for ``class_name``.

    The class prefix is removed via remove_prefix_from_class_name(). If the
    remainder starts with a digit, a leading underscore is added, because
    Swift identifiers cannot start with a digit. The result is then passed
    through to_swift_identifier_name().
    """
    name = remove_prefix_from_class_name(class_name)
    # Python 2 byte strings have no `isnumeric` method, and `name[0]` raises
    # IndexError on an empty name. Slicing plus isdigit() works for both
    # str and unicode and treats '' as "not a digit".
    if name[:1].isdigit():
        name = '_' + name
    return to_swift_identifier_name(name)
def record_identifier(class_name):
    """Return the Swift identifier used to reference ``class_name``'s record columns.

    This is the class name with its TS/OWS/SSK prefix removed, converted by
    to_swift_identifier_name().
    """
    return to_swift_identifier_name(remove_prefix_from_class_name(class_name))
# ---- Column Ordering

# NOTE(review): neither of these is referenced in this portion of the file;
# presumably reserved for persisting column ordering. Confirm before relying
# on them.
column_ordering_map = {}
has_loaded_column_ordering_map = False

# ---- Parsing

# Maps enum name -> Objective-C underlying type name (e.g. 'NSInteger').
# Merged from the "enums" section of each SDS JSON file by parse_sds_json().
enum_type_map = {}
def objc_type_for_enum(enum_name):
if enum_name not in enum_type_map:
print 'enum_type_map', enum_type_map
fail('Enum has unknown type:', enum_name)
enum_type = enum_type_map[enum_name]
return enum_type
def swift_type_for_enum(enum_name):
    """Return the Swift type name for the underlying type of enum ``enum_name``.

    The Objective-C underlying type comes from objc_type_for_enum(). fail()
    is called for Objective-C types that have no entry in the table below.
    """
    objc_type = objc_type_for_enum(enum_name)
    # Each Objective-C type appears exactly once. Note that 'uint64_t' is the
    # C spelling; the Swift type for `unsigned long long` is UInt64.
    swift_types = {
        'NSInteger': 'Int',
        'NSUInteger': 'UInt',
        'int32_t': 'Int32',
        'unsigned long long': 'UInt64',
        # NOTE(review): Swift imports C `unsigned long` as UInt and
        # `unsigned int` as UInt32. The existing mappings are kept unchanged
        # to match previously generated code. Confirm.
        'unsigned long': 'UInt64',
        'unsigned int': 'UInt',
    }
    swift_type = swift_types.get(objc_type)
    if swift_type is None:
        fail('Unknown objc type:', objc_type)
    return swift_type
def parse_sds_json(file_path):
    """Parse one SDS intermediary JSON file.

    Returns a dict mapping class name -> ParsedClass, built from the file's
    "classes" list. As a side effect, the file's "enums" dict is merged into
    the module-level ``enum_type_map``.
    """
    with open(file_path, 'rt') as json_file:
        parsed_json = json.load(json_file)

    parsed_classes = {}
    for class_dict in parsed_json['classes']:
        parsed_class = ParsedClass(class_dict)
        parsed_classes[parsed_class.name] = parsed_class

    enum_type_map.update(parsed_json['enums'])
    return parsed_classes
def try_to_parse_file(file_path):
filename = os.path.basename(file_path)
# print 'filename', filename
_, file_extension = os.path.splitext(filename)
if filename.endswith(sds_common.SDS_JSON_FILE_EXTENSION):
# print 'filename:', filename
print '\t', 'found', file_path
return parse_sds_json(file_path)
else:
return {}
def find_sds_intermediary_files_in_path(path):
print 'find_sds_intermediary_files_in_path', path
class_map = {}
if os.path.isfile(path):
class_map.update(try_to_parse_file(path))
else:
for rootdir, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.abspath(os.path.join(rootdir, filename))
class_map.update(try_to_parse_file(file_path))
return class_map
def update_subclass_map():
    """Populate ``global_subclass_map``: superclass name -> list of direct subclasses.

    Classes without a superclass name are skipped. Calling this more than
    once appends duplicate entries.
    """
    for clazz in global_class_map.values():
        parent_name = clazz.super_class_name
        if parent_name is None:
            continue
        global_subclass_map.setdefault(parent_name, []).append(clazz)
def all_descendents_of_class(clazz):
    """Return every transitive subclass of ``clazz`` in depth-first pre-order.

    Siblings are visited in name order. ``clazz`` itself is not included.
    NOTE: the subclass lists stored in ``global_subclass_map`` are sorted in
    place as a side effect.
    """
    def sorted_children(parent):
        children = global_subclass_map.get(parent.name, [])
        children.sort(key=lambda value: value.name)
        return children

    descendents = []
    # Children are pushed in reverse so that popping yields them in name order.
    pending = list(reversed(sorted_children(clazz)))
    while pending:
        current = pending.pop()
        descendents.append(current)
        pending.extend(reversed(sorted_children(current)))
    return descendents
def is_swift_class_name(swift_type):
    """Return True if ``swift_type`` names a class in ``global_class_map``."""
    parsed_class = global_class_map.get(swift_type)
    return parsed_class is not None
# ---- Config JSON

# Parsed contents of the code generation configuration JSON file.
# Replaced wholesale by parse_config_json() and read by the lookup helpers below.
configuration_json = {}
def parse_config_json(config_json_path):
print 'config_json_path', config_json_path
with open(config_json_path, 'rt') as f:
json_str = f.read()
json_data = json.loads(json_str)
global configuration_json
configuration_json = json_data
# We often use nullable NSNumber * for optional numerics (bool, int, int64, double, etc.).
# There's now way to infer which type we're boxing in NSNumber.
# Therefore, we need to specify that in the configuration JSON.
def swift_type_for_nsnumber(property):
nsnumber_types = configuration_json.get('nsnumber_types')
if nsnumber_types is None:
print 'Suggestion: update: %s' % ( str(global_args.config_json_path), )
fail('Configuration JSON is missing mapping for properties of type NSNumber.')
key = property.class_name + '.' + property.name
swift_type = nsnumber_types.get(key)
if swift_type is None:
print 'Suggestion: update: %s' % ( str(global_args.config_json_path), )
fail('Configuration JSON is missing mapping for properties of type NSNumber:', key)
return swift_type
# Some properties shouldn't get serialized; they are listed under
# "properties_to_ignore" in the configuration JSON.
# For now, there's just one: TSGroupModel.groupImage, which is a UIImage.
# TODO: We might end up extending the serialization to handle images,
# or we might store these as Data/NSData/blob.
def should_ignore_property(property):
    """Return True if ``property`` is excluded from serialization by the config.

    Calls fail() if the configuration JSON has no "properties_to_ignore" entry.
    """
    properties_to_ignore = configuration_json.get('properties_to_ignore')
    if properties_to_ignore is None:
        fail('Configuration JSON is missing list of properties to ignore during serialization.')
    qualified_name = property.class_name + '.' + property.name
    return qualified_name in properties_to_ignore
def custom_property_column_source(property):
    """Return the configured custom column source for ``property``, or None.

    Reads "custom_property_column_sources" from the configuration JSON and
    calls fail() if that dict is missing.
    """
    column_sources = configuration_json.get('custom_property_column_sources')
    if column_sources is None:
        fail('Configuration JSON is missing dict of custom_property_column_sources during serialization.')
    qualified_name = property.class_name + '.' + property.name
    return column_sources.get(qualified_name)
def should_ignore_class(clazz):
    """Return True if ``clazz`` or any known ancestor is excluded from serialization.

    A class is ignored when its name, or the name of any ancestor reachable
    through ``global_class_map``, appears in the configuration JSON's
    "class_to_skip_serialization" list. The walk up the hierarchy stops at a
    class with no superclass name or with a superclass that is not in
    ``global_class_map``.
    """
    class_to_skip_serialization = configuration_json.get('class_to_skip_serialization')
    if class_to_skip_serialization is None:
        fail('Configuration JSON is missing list of classes to ignore during serialization.')

    current = clazz
    while True:
        if current.name in class_to_skip_serialization:
            return True
        parent_name = current.super_class_name
        if parent_name is None or parent_name not in global_class_map:
            return False
        current = global_class_map[parent_name]
def is_flagged_as_enum_property(property):
    """Return True if the config lists ``property`` under "enum_properties".

    Calls fail() if the configuration JSON has no "enum_properties" entry.
    """
    flagged_properties = configuration_json.get('enum_properties')
    if flagged_properties is None:
        fail('Configuration JSON is missing list of properties to treat as enums.')
    is_flagged = (property.class_name + '.' + property.name) in flagged_properties
    return is_flagged
def accessor_name_for_property(property):
    """Return the accessor name used on the model in generated serialization code.

    Uses the "custom_accessors" override from the configuration JSON when one
    exists for "<class_name>.<property_name>", otherwise the property's own
    name. Calls fail() if the "custom_accessors" dict is missing.
    """
    custom_accessors = configuration_json.get('custom_accessors')
    if custom_accessors is None:
        fail('Configuration JSON is missing list of custom property accessors.')
    qualified_name = property.class_name + '.' + property.name
    if qualified_name in custom_accessors:
        return custom_accessors[qualified_name]
    return property.name
def custom_column_name_for_property(property):
    """Return the configured custom column name for ``property``, or None.

    Reads "custom_column_names" from the configuration JSON and calls fail()
    if that dict is missing.
    """
    custom_column_names = configuration_json.get('custom_column_names')
    if custom_column_names is None:
        fail('Configuration JSON is missing list of custom column names.')
    qualified_name = property.class_name + '.' + property.name
    if qualified_name in custom_column_names:
        return custom_column_names[qualified_name]
    return None
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Parse Swift AST.')
parser.add_argument('--src-path', required=True, help='used to specify a path to process.')
parser.add_argument('--search-path', required=True, help='used to specify a path to process.')
parser.add_argument('--record-type-swift-path', required=True, help='path of the record type enum swift file.')
parser.add_argument('--record-type-json-path', required=True, help='path of the record type map json file.')
parser.add_argument('--config-json-path', required=True, help='path of the json file with code generation config info.')
args = parser.parse_args()
global_args = args
src_path = os.path.abspath(args.src_path)
search_path = os.path.abspath(args.search_path)
record_type_swift_path = os.path.abspath(args.record_type_swift_path)
record_type_json_path = os.path.abspath(args.record_type_json_path)
config_json_path = os.path.abspath(args.config_json_path)
# We control the code generation process using a JSON config file.
print
print 'Parsing Config'
parse_config_json(config_json_path)
# The code generation needs to understand the class hierarchy so that
# it can:
#
# * Define table schemas that include the superset of properties in
# the model class hierarchies.
# * Generate deserialization methods that handle all subclasses.
# * etc.
print
print 'Parsing Global Class Map'
global_class_map.update(find_sds_intermediary_files_in_path(search_path))
print 'global_class_map', global_class_map
update_subclass_map()
print
print 'Parsing Record Type Map'
update_record_type_map(record_type_swift_path, record_type_json_path)
print
print 'Processing'
process_class_map(find_sds_intermediary_files_in_path(src_path))