zkgroup/codegen/codegen_java.py
Ehren Kret bbd71fd1fb
Some checks failed
Code Formatting / Code Formatting (push) Has been cancelled
Tests / Rust + Java + Node (push) Has been cancelled
Tests / iOS (push) Has been cancelled
Release artifacts / Build Linux + Android (push) Has been cancelled
Release artifacts / Build Windows (push) Has been cancelled
Release artifacts / Build macOS (Node) (push) Has been cancelled
Add three new accessor methods for the presentation
2021-09-29 15:26:37 -05:00

764 lines
27 KiB
Python

# Java source skeleton for a generated value class that extends ByteArray.
# print_class fills it with the `%` operator; placeholders are dir_section,
# imports, class_name, size, static_methods, constructors, methods and
# serialize_method.
template_class = \
"""//
// Copyright (C) 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: GPL-3.0-only
//
// Generated by zkgroup/codegen/codegen.py - do not edit
package org.signal.zkgroup%(dir_section)s;
%(imports)s
public final class %(class_name)s extends ByteArray {
public static final int SIZE = %(size)s;
%(static_methods)s%(constructors)s
%(methods)s%(serialize_method)s}
"""
# Constructor emitted for every generated class. It forwards to the
# superclass constructor and then splices in the optional
# %(check_valid_contents)s snippet.
template_constructor = \
"""
%(constructor_access)s %(class_name)s(%(constructor_contents_type)s contents) %(constructor_exception_decl)s {
super(%(constructor_contents)s, SIZE%(runtime_error_bool)s);%(check_valid_contents)s
}
"""
# Additional constructor emitted by print_class when the class has string
# contents. The template text is identical to template_constructor; only the
# values substituted by print_class differ.
template_constructor_for_string_contents = \
"""
%(constructor_access)s %(class_name)s(%(constructor_contents_type)s contents) %(constructor_exception_decl)s {
super(%(constructor_contents)s, SIZE%(runtime_error_bool)s);%(check_valid_contents)s
}
"""
# Java serialize() body used when the class has string contents: decodes the
# stored bytes as UTF-8. Contains no placeholders.
serialize_method_string = \
"""
public String serialize() {
try {
return new String(contents, "UTF-8");
} catch (UnsupportedEncodingException e) {
throw new AssertionError();
}
}
"""
# Java serialize() body used for binary contents: returns a copy of the
# stored byte array. Contains no placeholders.
serialize_method_binary = \
"""
public byte[] serialize() {
return contents.clone();
}
"""
# Java source skeleton for a class that wraps another generated class in a
# private final field instead of extending ByteArray. Placeholders are
# dir_section, imports, class_name, wrapped_class_type, wrapped_class_var,
# static_methods and methods. print_class also passes size, constructors and
# serialize_method, which this template does not reference.
template_wrapping_class = \
"""//
// Copyright (C) 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: GPL-3.0-only
//
// Generated by zkgroup/codegen/codegen.py - do not edit
package org.signal.zkgroup%(dir_section)s;
%(imports)s
public class %(class_name)s {
private final %(wrapped_class_type)s %(wrapped_class_var)s;
%(static_methods)s
public %(class_name)s(%(wrapped_class_type)s %(wrapped_class_var)s) {
this.%(wrapped_class_var)s = %(wrapped_class_var)s;
}
%(methods)s
}
"""
# Constructor snippet that validates `contents` through the class's
# <name>CheckValidContentsJNI native call. An input error is reported as the
# checked InvalidInputException; any other non-OK code as ZkGroupError.
template_check_valid_contents_constructor = \
"""
int ffi_return = Native.%(class_name_lower_camel)sCheckValidContentsJNI(contents);
if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
throw new InvalidInputException("FFI_RETURN_INPUT_ERROR");
}
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}"""
# Same validation snippet, but an input error is wrapped in an unchecked
# IllegalArgumentException. print_class selects it when the class has
# runtime_error_on_serialize set.
template_check_valid_contents_constructor_runtime_error = \
"""
int ffi_return = Native.%(class_name_lower_camel)sCheckValidContentsJNI(contents);
if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
throw new IllegalArgumentException(new InvalidInputException("FFI_RETURN_INPUT_ERROR"));
}
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}"""
# Static factory method: allocates an output buffer of <return type>.SIZE,
# calls the JNI function and wraps the result in the return type. The
# constructor's checked InvalidInputException is converted to AssertionError.
template_static_method = \
"""
%(access)s static %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_name)s.SIZE];%(get_rand)s
int ffi_return = Native.%(jni_method_name)s(%(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
try {
return new %(return_name)s(newContents);
} catch (InvalidInputException e) {
throw new AssertionError(e);
}
}
"""
# Variant of template_static_method selected by print_class when the return
# type is flagged in runtime_error_on_serialize_dict; it catches
# IllegalArgumentException instead of InvalidInputException.
template_static_method_retval_runtime_error_on_serialize = \
"""
%(access)s static %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_name)s.SIZE];%(get_rand)s
int ffi_return = Native.%(jni_method_name)s(%(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
try {
return new %(return_name)s(newContents);
} catch (IllegalArgumentException e) {
throw new AssertionError(e);
}
}
"""
# Convenience static overload emitted for "*_deterministic" methods: it simply
# forwards to %(full_method_name)s with the arguments built by
# get_rand_wrapper_args.
template_static_method_rand_wrapper = \
"""
%(access)s static %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
return %(full_method_name)s(%(param_args)s);
}
"""
# Instance method returning another generated class: calls the JNI function
# with %(contents)s as first argument and wraps the output buffer in the
# return type, converting InvalidInputException to AssertionError.
template_method = \
"""
%(access)s %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_len)s];%(get_rand)s
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
try {
return new %(return_name)s(newContents);
} catch (InvalidInputException e) {
throw new AssertionError(e);
}
}
"""
# Variant of template_method used when the return type is flagged in
# runtime_error_on_serialize_dict: the return value is constructed without a
# surrounding try/catch.
template_method_retval_runtime_error_on_serialize = \
"""
%(access)s %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_len)s];%(get_rand)s
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
return new %(return_name)s(newContents);
}
"""
# Instance method for return_type "boolean": emitted as a Java `void` method
# that signals failure by throwing; no output buffer is passed to JNI.
template_method_bool = \
"""
%(access)s void %(method_name)s(%(param_decls)s) %(exception_decl)s{
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)s);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
}
"""
# Instance method returning a java.util.UUID decoded from the output buffer
# with UUIDUtil.deserialize.
template_method_uuid = \
"""
%(access)s UUID %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_len)s];
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
return UUIDUtil.deserialize(newContents);
}
"""
# Instance method returning the raw output buffer as byte[].
template_method_bytearray = \
"""
%(access)s byte[] %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[%(return_len)s];%(get_rand)s
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
return newContents;
}
"""
# Instance method returning an int read from a fixed 4-byte output buffer via
# ByteBuffer.getInt().
template_method_int = \
"""
%(access)s int %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[4];
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
return ByteBuffer.wrap(newContents).getInt();
}
"""
# Instance method returning a long read from a fixed 8-byte output buffer via
# ByteBuffer.getLong().
template_method_long = \
"""
%(access)s long %(method_name)s(%(param_decls)s) %(exception_decl)s{
byte[] newContents = new byte[8];
int ffi_return = Native.%(jni_method_name)s(%(contents)s, %(param_args)snewContents);%(exception_check)s
if (ffi_return != Native.FFI_RETURN_OK) {
throw new ZkGroupError("FFI_RETURN!=OK");
}
return ByteBuffer.wrap(newContents).getLong();
}
"""
# Convenience instance overload emitted for "*_deterministic" methods. The
# access level is hard-coded to `public`; the %(access)s value passed by
# print_class is not referenced here.
template_method_rand_wrapper = \
"""
public %(return_name)s %(method_name)s(%(param_decls)s) %(exception_decl)s{
return %(full_method_name)s(%(param_args)s);
}
"""
# Opening part of the generated Native.java: FFI return-code constants,
# RANDOM_LENGTH, a static initializer that loads the native library (from a
# bundled "/libzkgroup" resource when present, otherwise via
# System.loadLibrary), and the temp-file loader helper. The class body is
# intentionally left open; the `native` method declarations are appended
# after it and template_native_end closes the class.
template_native = \
"""//
// Copyright (C) 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: GPL-3.0-only
//
// Generated by zkgroup/codegen/codegen.py - do not edit
package org.signal.zkgroup.internal;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
public final class Native {
public static final int FFI_RETURN_OK = 0;
public static final int FFI_RETURN_INTERNAL_ERROR = 1; // ZkGroupError
public static final int FFI_RETURN_INPUT_ERROR = 2;
public static final int RANDOM_LENGTH = 32;
static {
try {
String osName = System.getProperty("os.name").toLowerCase(java.util.Locale.ROOT);
boolean isMacOs = osName.startsWith("mac os x");
String extension = isMacOs ? ".dylib" : ".so";
try (InputStream in = Native.class.getResourceAsStream("/libzkgroup" + extension)) {
if (in != null) {
copyToTempFileAndLoad(in, extension);
} else {
System.loadLibrary("zkgroup");
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private Native() {
}
private static void copyToTempFileAndLoad(InputStream in, String extension) throws IOException {
File tempFile = Files.createTempFile("resource", extension).toFile();
tempFile.deleteOnExit();
try (OutputStream out = new FileOutputStream(tempFile)) {
copy(in, out);
}
System.load(tempFile.getAbsolutePath());
}
"""
# Closing part of the generated Native.java: the stream copy helper and the
# closing brace of the class. produce_output appends it after all JNI
# declarations have been collected.
template_native_end = \
"""
private static void copy(InputStream in, OutputStream out) throws IOException {
byte[] buffer = new byte[4096];
int read;
while ((read = in.read(buffer)) != -1) {
out.write(buffer, 0, read);
}
}
}
"""
# Module-level accumulator for the Native.java source. It starts as the
# opening template and is extended in place by append_jni_function_decl,
# append_jni_check_valid_contents and produce_output.
native_string = template_native
def add_import(import_strings, class_dir_dict, my_dir_name, class_name):
    """Append a Java import for ``class_name`` unless it shares our package.

    Args:
        import_strings: list of import lines; extended in place.
        class_dir_dict: maps a class's snake-case name to its directory name
            object (anything exposing ``snake()``).
        my_dir_name: directory name object of the class being generated.
        class_name: name object of the class to import; must expose
            ``snake()`` and ``camel()``.
    """
    target_dir = class_dir_dict[class_name.snake()].snake()
    # Same directory (including both being the root package): no import needed.
    if target_dir == my_dir_name.snake():
        return
    simple_name = class_name.camel()
    if target_dir:
        line = "import org.signal.zkgroup.%s.%s;" % (target_dir, simple_name)
    else:
        line = "import org.signal.zkgroup.%s;" % (simple_name)
    import_strings.append(line)
def get_decls(params, import_strings, class_dir_dict, my_dir_name):
    """Build the Java parameter declaration list for a generated method.

    Each entry of ``params`` is indexable: item 0 is the type tag and item 1
    is a name object exposing ``snake()``, ``camel()`` and ``lower_camel()``.
    A parameter named "randomness" becomes ``SecureRandom secureRandom``; a
    "class" parameter is declared with its own camel-case type and has its
    import registered; anything else uses the type tag verbatim.

    Args:
        params: sequence of parameter descriptions.
        import_strings: list of import lines; extended in place.
        class_dir_dict: passed through to add_import.
        my_dir_name: passed through to add_import.

    Returns:
        The declarations joined with ", " (empty string for no parameters).
    """
    declarations = []
    for param in params:
        kind = param[0]
        name = param[1]
        if name.snake() == "randomness":
            declarations.append("SecureRandom secureRandom")
            import_strings.append("import java.security.SecureRandom;")
        elif kind == "class":
            declarations.append(name.camel() + " " + name.lower_camel())
            add_import(import_strings, class_dir_dict, my_dir_name, name)
        else:
            declarations.append(kind + " " + name.lower_camel())
    return ", ".join(declarations)
def get_rand_wrapper_decls(params):
    """Build the Java parameter list for a randomness-free wrapper overload.

    Parameters named "randomness" are left out. "class" parameters are
    declared with their camel-case type; all others use the type tag
    (``param[0]``) verbatim.

    Args:
        params: sequence of indexable parameter descriptions (type tag at
            index 0, name object at index 1).

    Returns:
        The declarations joined with ", " (empty string if none remain).
    """
    def declare(param):
        java_type = param[1].camel() if param[0] == "class" else param[0]
        return java_type + " " + param[1].lower_camel()

    return ", ".join(
        declare(param) for param in params if param[1].snake() != "randomness"
    )
def get_args(params, import_strings, commaAtEnd):
    """Build the Java argument expressions passed to a JNI call.

    ``byte[]``, ``int`` and ``long`` parameters are passed by name, ``UUID``
    parameters are wrapped in ``UUIDUtil.serialize(...)`` (registering the
    two imports that needs), a parameter named "randomness" becomes the local
    ``random``, and everything else passes
    ``<name>.getInternalContentsForJNI()``.

    Args:
        params: sequence of indexable parameter descriptions (type tag at
            index 0, name object at index 1).
        import_strings: list of import lines; extended in place.
        commaAtEnd: when true and there is at least one argument, the result
            keeps a trailing ", " so the caller can append another argument.

    Returns:
        The argument expressions joined with ", ".
    """
    passed_by_name = ("byte[]", "int", "long")
    arguments = []
    for param in params:
        kind = param[0]
        name = param[1]
        if kind in passed_by_name:
            arguments.append(name.lower_camel())
        elif kind == "UUID":
            arguments.append("UUIDUtil.serialize(" + name.lower_camel() + ")")
            import_strings.append("import java.util.UUID;")
            import_strings.append("import org.signal.zkgroup.util.UUIDUtil;")
        elif name.snake() == "randomness":
            arguments.append("random")
        else:
            arguments.append(name.lower_camel() + ".getInternalContentsForJNI()")
    joined = ", ".join(arguments)
    if arguments and commaAtEnd:
        joined += ", "
    return joined
def get_jni_arg_decls(params, selfBool, commaAtEndBool):
    """Build the Java parameter list of a ``native`` method declaration.

    ``int`` and ``long`` parameters keep their type; every other parameter
    crosses the JNI boundary as ``byte[]``.

    Args:
        params: sequence of indexable parameter descriptions (type tag at
            index 0, name object exposing ``lower_camel()`` at index 1).
        selfBool: when true, a leading ``byte[] self`` parameter is emitted.
        commaAtEndBool: when true, a trailing ``byte[] output`` parameter is
            emitted.

    Returns:
        The declarations joined with ", " (empty string if there are none).
    """
    declarations = ["byte[] self"] if selfBool else []
    for param in params:
        java_type = param[0] if param[0] in ("int", "long") else "byte[]"
        declarations.append(f"{java_type} {param[1].lower_camel()}")
    if commaAtEndBool:
        declarations.append("byte[] output")
    return ", ".join(declarations)
def get_rand_wrapper_args(params, commaAtEnd):
    """Build the arguments a wrapper overload forwards to the full method.

    Every parameter is forwarded by its lower-camel name, except that a
    non-``byte[]`` parameter named "randomness" is replaced by a freshly
    constructed ``new SecureRandom()``.

    Args:
        params: sequence of indexable parameter descriptions (type tag at
            index 0, name object at index 1).
        commaAtEnd: when true and there is at least one argument, the result
            keeps a trailing ", ".

    Returns:
        The argument expressions joined with ", ".
    """
    forwarded = [
        "new SecureRandom()"
        if param[0] != "byte[]" and param[1].snake() == "randomness"
        else param[1].lower_camel()
        for param in params
    ]
    joined = ", ".join(forwarded)
    if forwarded and commaAtEnd:
        return joined + ", "
    return joined
def append_jni_function_decl(jni_method_name, params, selfBool, commaAtEndBool):
    """Append one ``native`` method declaration to the Native.java source.

    The declaration always returns ``int``; its parameter list comes from
    get_jni_arg_decls. The module-level ``native_string`` is extended in
    place.

    Args:
        jni_method_name: Java name of the native method.
        params: parameter descriptions, passed to get_jni_arg_decls.
        selfBool: passed to get_jni_arg_decls.
        commaAtEndBool: passed to get_jni_arg_decls.
    """
    global native_string
    arg_decls = get_jni_arg_decls(params, selfBool, commaAtEndBool)
    declaration = " public static native int " + jni_method_name + "(%s);\n" % arg_decls
    native_string = native_string + declaration
def append_jni_check_valid_contents(jni_method_name):
    """Append a ``native`` check-valid-contents declaration to Native.java.

    The declared method takes a single ``byte[] self`` argument and returns
    ``int``. The module-level ``native_string`` is extended in place.

    Args:
        jni_method_name: Java name of the native method.
    """
    global native_string
    declaration = " public static native int " + jni_method_name + "(byte[] self);\n"
    native_string = native_string + declaration
def print_class(c, runtime_error_on_serialize_dict, class_dir_dict):
    """Render the Java source for one class description and return it.

    The function walks the class's static and instance methods, fills the
    matching method templates, registers the corresponding ``native``
    declarations (which extends the module-level ``native_string`` through
    append_jni_function_decl / append_jni_check_valid_contents), collects the
    Java imports, and finally fills template_wrapping_class (when
    ``c.wrap_class`` is set) or template_class.

    Args:
        c: class description. Attributes read here: methods, static_methods,
            dir_name, wrap_class, class_name, class_len, class_len_int,
            check_valid_contents, runtime_error_on_serialize, no_serialize,
            string_contents.
        runtime_error_on_serialize_dict: maps a class's snake-case name to
            its runtime_error_on_serialize flag; consulted for each method's
            return type to pick the method template.
        class_dir_dict: maps a class's snake-case name to its directory name
            object; passed to add_import / get_decls.

    Returns:
        The complete Java source text for the class.
    """
    static_methods_string = ""
    # The Native import is only added when the class declares at least one method.
    if len(c.methods) == 0 and len(c.static_methods) == 0:
        import_strings = []
    else:
        import_strings = [\
            "import org.signal.zkgroup.internal.Native;",
            ]
    global native_string
    my_dir_name = c.dir_name
    # `contents` is the Java expression passed as the first JNI argument of
    # instance methods: the own field, or the wrapped object's bytes.
    if c.wrap_class == None:
        contents = "contents"
    else:
        contents = c.wrap_class.lower_camel() + ".getInternalContentsForJNI()"
        add_import(import_strings, class_dir_dict, my_dir_name, c.wrap_class)
    for method in c.static_methods:
        exception_decl = ""
        exception_check =""
        # A static method whose only parameter belongs to a "_deterministic"
        # variant gets no `throws` clause.
        if len(method.params) > 1 or (len(method.params) == 1 and not method.method_name.snake().endswith("_deterministic")):
            if method.runtime_error == False:
                if method.verification == False:
                    # Root-package classes share the exception's package and need no import.
                    if my_dir_name.snake() != "":
                        import_strings.append("import org.signal.zkgroup.VerificationFailedException;")
                    exception_decl = "throws VerificationFailedException "
                    exception_check ="""\n if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
throw new VerificationFailedException();
}"""
                # NOTE(review): this `else` is taken to pair with the
                # `verification` test above - confirm. The constant
                # FFI_RETURN_VERIFICATION_FAILED is not declared in
                # template_native as it appears in this file.
                else:
                    exception_decl = "throws VerificationFailedException "
                    exception_check ="""\n if (ffi_return == Native.FFI_RETURN_VERIFICATION_FAILED) {
throw new VerificationFailedException();
}"""
        access = "public"
        method_name = method.method_name.lower_camel()
        get_rand = ""
        if method.method_name.snake().endswith("_deterministic"):
            import_strings.append("import java.security.SecureRandom;")
            # The Java method drops the "Deterministic" suffix; the wrapper
            # below is an overload of the same name.
            method_name = method.method_name.lower_camel()[:-len("Deterministic")]
            param_args = get_rand_wrapper_args(method.params, False)
            get_rand = """\n byte[] random = new byte[Native.RANDOM_LENGTH];
secureRandom.nextBytes(random);"""
            # Extra keys in the mapping (e.g. exception_check) are not
            # referenced by this template.
            static_methods_string += template_static_method_rand_wrapper % {
                "method_name": method_name,
                "return_name": method.return_name.camel(),
                "full_method_name": method_name,
                "param_decls": get_rand_wrapper_decls(method.params),
                "param_args": param_args,
                "access": access,
                "exception_decl": exception_decl,
                "exception_check": exception_check,
                }
        # Trailing comma requested: the template appends `newContents` directly.
        param_args = get_args(method.params, import_strings, True)
        if c.wrap_class == None:
            jni_method_name = c.class_name.lower_camel() + method.method_name.camel() + "JNI"
        else:
            jni_method_name = c.wrap_class.lower_camel() + method.method_name.camel() + "JNI"
        append_jni_function_decl(jni_method_name, method.params, False, True)
        if runtime_error_on_serialize_dict[method.return_name.snake()]:
            template = template_static_method_retval_runtime_error_on_serialize
            import_strings.append("import org.signal.zkgroup.ZkGroupError;")
        else:
            template = template_static_method
            if my_dir_name.snake() != "":
                import_strings.append("import org.signal.zkgroup.InvalidInputException;")
        static_methods_string += template % {
            "method_name": method_name,
            "return_name": method.return_name.camel(),
            "return_len": c.class_len,
            "param_decls": get_decls(method.params, import_strings, class_dir_dict, my_dir_name),
            "param_args": param_args,
            "jni_method_name": jni_method_name,
            "access": access,
            "exception_decl": exception_decl,
            "exception_check": exception_check,
            "get_rand": get_rand,
            }
    methods_string = ""
    for method in c.methods:
        # check_valid_contents is emitted inside the constructor, not as a method.
        if method.method_name.snake() == "check_valid_contents":
            continue
        exception_decl = ""
        exception_check =""
        if len(method.params) != 0:
            if method.runtime_error == False:
                if method.verification == False:
                    if my_dir_name.snake() != "":
                        import_strings.append("import org.signal.zkgroup.VerificationFailedException;")
                    exception_decl = "throws VerificationFailedException "
                    exception_check ="""\n if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
throw new VerificationFailedException();
}"""
                # NOTE(review): `else` pairing inferred as in the static-method
                # loop above - confirm. FFI_RETURN_VERIFICATION_FAILED is not
                # declared in template_native as it appears in this file.
                else:
                    exception_decl = "throws VerificationFailedException "
                    exception_check ="""\n if (ffi_return == Native.FFI_RETURN_VERIFICATION_FAILED) {
throw new VerificationFailedException();
}"""
        access = "public"
        method_name = method.method_name.lower_camel()
        get_rand = ""
        if method.method_name.snake().endswith("_deterministic"):
            import_strings.append("import java.security.SecureRandom;")
            method_name = method.method_name.lower_camel()[:-len("Deterministic")]
            param_args = get_rand_wrapper_args(method.params, False)
            get_rand = """\n byte[] random = new byte[Native.RANDOM_LENGTH];
secureRandom.nextBytes(random);"""
            # Maybe add more cases to support different return types
            if method.return_type == "byte[]":
                return_name = "byte[]"
            else:
                return_name = method.return_name.camel()
            methods_string += template_method_rand_wrapper % {
                "contents": contents,
                "method_name": method_name,
                "return_name": return_name,
                "full_method_name": method_name,
                "param_decls": get_rand_wrapper_decls(method.params),
                "param_args": param_args,
                "access": access,
                "exception_decl": exception_decl,
                "exception_check": exception_check,
                }
        if c.wrap_class == None:
            jni_method_name = c.class_name.lower_camel() + method.method_name.camel() + "JNI"
        else:
            jni_method_name = c.wrap_class.lower_camel() + method.method_name.camel() + "JNI"
        # Stays None unless a branch below computes a specific buffer length.
        return_len = None
        if method.return_type == "boolean":
            template = template_method_bool
            # No output buffer follows the arguments, so no trailing comma.
            param_args = get_args(method.params, import_strings, False)
            append_jni_function_decl(jni_method_name, method.params, True, False)
        elif method.return_type == "int":
            template = template_method_int
            # NOTE(review): no trailing comma is requested here, yet
            # template_method_int emits `%(param_args)snewContents`; this only
            # yields valid Java when the method has no parameters - confirm.
            param_args = get_args(method.params, import_strings, False)
            append_jni_function_decl(jni_method_name, method.params, True, True)
            # The trailing comma builds a 1-tuple, so `+=` adds a single string.
            import_strings += "import java.nio.ByteBuffer;",
        elif method.return_type == "long":
            template = template_method_long
            # NOTE(review): same trailing-comma concern as the "int" branch.
            param_args = get_args(method.params, import_strings, False)
            append_jni_function_decl(jni_method_name, method.params, True, True)
            import_strings += "import java.nio.ByteBuffer;",
        elif method.return_type == "UUID":
            import_strings.append("import java.util.UUID;")
            template = template_method_uuid
            param_args = get_args(method.params, import_strings, True)
            append_jni_function_decl(jni_method_name, method.params, True, True)
        elif method.return_type == "byte[]": # copied from UUID?
            template = template_method_bytearray
            param_args = get_args(method.params, import_strings, True)
            append_jni_function_decl(jni_method_name, method.params, True, True)
            # Output length expressed relative to one input array's length.
            if method.relative_return_size is not None:
                return_len = f"{method.params[method.relative_return_size][1].lower_camel()}.length + {method.return_size_increment}"
        else:
            # Any other return type is treated as a generated class.
            add_import(import_strings, class_dir_dict, my_dir_name, method.return_name)
            if runtime_error_on_serialize_dict[method.return_name.snake()]:
                template = template_method_retval_runtime_error_on_serialize
            else:
                template = template_method
            # NOTE(review): this import is placed at the level of the outer
            # branch (both templates throw ZkGroupError); it may originally
            # have belonged to the inner `else` only - confirm.
            import_strings.append("import org.signal.zkgroup.ZkGroupError;")
            param_args = get_args(method.params, import_strings, True)
            append_jni_function_decl(jni_method_name, method.params, True, True)
        if method.return_name.snake() == "uuid":
            return_len = "UUIDUtil.UUID_LENGTH"
        elif return_len is None:
            return_len = method.return_name.camel() + ".SIZE"
        methods_string += template % {
            "contents": contents,
            "method_name": method_name,
            "return_name": method.return_name.camel(),
            "return_len": return_len,
            "param_decls": get_decls(method.params, import_strings, class_dir_dict, my_dir_name),
            "param_args": param_args,
            "jni_method_name": jni_method_name,
            "access": access,
            "exception_decl": exception_decl,
            "exception_check": exception_check,
            "get_rand": get_rand,
            }
    # Package suffix appended to "org.signal.zkgroup" in the class template.
    if c.dir_name.snake() != "":
        dir_section = "." + c.dir_name.snake()
    else:
        dir_section = ""
    constructor_exception_decl = "throws InvalidInputException" # overwritten if needed
    runtime_error_bool = ""
    if c.check_valid_contents:
        if c.runtime_error_on_serialize:
            # The runtime-error variant throws unchecked exceptions only, so
            # the constructor declares nothing.
            constructor_exception_decl = "" # overwritten if needed
            runtime_error_bool = ", true"
            check_valid_contents = template_check_valid_contents_constructor_runtime_error % {
                "class_name_lower_camel": c.class_name.lower_camel(),
                }
            jni_method_name = c.class_name.lower_camel() + "CheckValidContentsJNI"
            append_jni_check_valid_contents(jni_method_name)
            if my_dir_name.snake() != "":
                import_strings.append("import org.signal.zkgroup.ZkGroupError;")
                import_strings.append("import org.signal.zkgroup.InvalidInputException;")
        else:
            constructor_exception_decl = "throws InvalidInputException" # overwritten if needed
            check_valid_contents = template_check_valid_contents_constructor % {
                "class_name_lower_camel": c.class_name.lower_camel(),
                }
            jni_method_name = c.class_name.lower_camel() + "CheckValidContentsJNI"
            append_jni_check_valid_contents(jni_method_name)
            if my_dir_name.snake() != "":
                import_strings.append("import org.signal.zkgroup.InvalidInputException;")
                import_strings.append("import org.signal.zkgroup.ZkGroupError;")
    else:
        check_valid_contents = ""
    if constructor_exception_decl.endswith("InvalidInputException") and my_dir_name.snake() != "":
        import_strings.append("import org.signal.zkgroup.InvalidInputException;")
    if c.no_serialize:
        constructor_access = "private"
    else:
        constructor_access = "public"
    if c.wrap_class == None:
        import_strings.append("import org.signal.zkgroup.internal.ByteArray;")
    # De-duplicate and order the imports collected so far.
    import_strings = list(set(import_strings))
    import_strings.sort()
    # constructors
    constructors_string = ""
    constructor_contents = "contents"
    constructor_contents_type = "byte[]"
    constructors_string += template_constructor % {
        "class_name": c.class_name.camel(),
        "constructor_contents": constructor_contents,
        "constructor_contents_type": constructor_contents_type,
        "constructor_access": constructor_access,
        "constructor_exception_decl": constructor_exception_decl,
        "runtime_error_bool": runtime_error_bool,
        "check_valid_contents": check_valid_contents,
        }
    if c.string_contents == False:
        serialize_method = serialize_method_binary
    else:
        constructor_contents = 'contents.getBytes("UTF-8")'
        constructor_contents_type = "String"
        # Appended after the sort above, so this import lands last.
        import_strings.append("import java.io.UnsupportedEncodingException;")
        serialize_method = serialize_method_string
        # NOTE(review): when constructor_exception_decl is "" the clause below
        # becomes ", UnsupportedEncodingException" without a `throws`
        # keyword - confirm that combination cannot occur.
        constructors_string += template_constructor_for_string_contents % {
            "class_name": c.class_name.camel(),
            "constructor_contents": constructor_contents,
            "constructor_contents_type": constructor_contents_type,
            "constructor_access": constructor_access,
            "constructor_exception_decl": constructor_exception_decl + ", UnsupportedEncodingException",
            "runtime_error_bool": runtime_error_bool,
            "check_valid_contents": check_valid_contents,
            }
    # Drop the final character (the template's trailing newline).
    constructors_string = constructors_string[:-1]
    if c.wrap_class != None:
        # size, constructors and serialize_method are passed but not
        # referenced by template_wrapping_class.
        class_string = template_wrapping_class % {
            "imports": "\n".join(import_strings),
            "wrapped_class_type": c.wrap_class.camel(),
            "wrapped_class_var": c.wrap_class.lower_camel(),
            "dir_section": dir_section,
            "class_name": c.class_name.camel(),
            "size": c.class_len_int,
            "constructors": constructors_string,
            "static_methods": static_methods_string,
            "methods": methods_string,
            "serialize_method": serialize_method
            }
    else:
        class_string = template_class % {
            "imports": "\n".join(import_strings),
            "dir_section": dir_section,
            "class_name": c.class_name.camel(),
            "size": c.class_len_int,
            "constructors": constructors_string,
            "static_methods": static_methods_string,
            "methods": methods_string,
            "serialize_method": serialize_method
            }
    return class_string
def produce_output(classes):
    """Generate every Java source file, including ``Native.java``.

    For each class description that is not flagged ``no_class`` the rendered
    source is written to ``java/<dir>/<Class>.java`` (or ``java/<Class>.java``
    for classes without a directory). Afterwards the accumulated JNI
    declarations are closed with template_native_end and written to
    ``java/internal/Native.java``. The module-level ``native_string`` is
    extended in place, as before.

    Args:
        classes: iterable of class descriptions; each must expose
            class_name, dir_name, runtime_error_on_serialize and no_class,
            plus everything print_class reads.
    """
    global native_string

    # Lookup tables that print_class needs for classes other than its own.
    runtime_error_on_serialize_dict = {}
    class_dir_dict = {}
    for c in classes:
        key = c.class_name.snake()
        runtime_error_on_serialize_dict[key] = c.runtime_error_on_serialize
        class_dir_dict[key] = c.dir_name

    for c in classes:
        if c.no_class:
            continue
        if c.dir_name.snake() != "":
            path = "java/%s/%s.java" % (c.dir_name.snake(), c.class_name.camel())
        else:
            path = "java/%s.java" % c.class_name.camel()
        # Render before opening so a failure in print_class neither truncates
        # the existing file nor leaks an open handle.
        source = print_class(c, runtime_error_on_serialize_dict, class_dir_dict)
        with open(path, "w") as f:
            f.write(source)

    with open("java/internal/Native.java", "w") as f:
        native_string += template_native_end
        f.write(native_string)