disabled sparse option in osce export script

This commit is contained in:
Jan Buethe 2024-02-15 15:25:06 +01:00
parent ffd1b0b137
commit 735117b6d7
No known key found for this signature in database
GPG Key ID: 9E32027A35B36314

View File

@ -54,14 +54,14 @@ parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoin
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
sparse_default=False
schedules = {
'nolace': [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None)),
@ -71,18 +71,18 @@ schedules = {
('af2', dict(quantize=True, scale=None)),
('af3', dict(quantize=True, scale=None)),
('af4', dict(quantize=True, scale=None)),
('post_cf1', dict(quantize=True, scale=None, sparse=True)),
('post_cf2', dict(quantize=True, scale=None, sparse=True)),
('post_af1', dict(quantize=True, scale=None, sparse=True)),
('post_af2', dict(quantize=True, scale=None, sparse=True)),
('post_af3', dict(quantize=True, scale=None, sparse=True))
('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
],
'lace' : [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))