Fix issue with prekey update
This commit is contained in:
		
							parent
							
								
									7206b4da25
								
							
						
					
					
						commit
						2c0ad7feb7
					
				@ -90,7 +90,11 @@ public class AccountHelper {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
            updateAccountAttributes();
 | 
					            updateAccountAttributes();
 | 
				
			||||||
            context.getPreKeyHelper().refreshPreKeysIfNecessary();
 | 
					            if (account.getPreviousStorageVersion() < 9) {
 | 
				
			||||||
 | 
					                context.getPreKeyHelper().forceRefreshPreKeys();
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                context.getPreKeyHelper().refreshPreKeysIfNecessary();
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            if (account.getAci() == null || account.getPni() == null) {
 | 
					            if (account.getAci() == null || account.getPni() == null) {
 | 
				
			||||||
                checkWhoAmiI();
 | 
					                checkWhoAmiI();
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
				
			|||||||
@ -41,6 +41,11 @@ public class PreKeyHelper {
 | 
				
			|||||||
        refreshPreKeysIfNecessary(ServiceIdType.PNI);
 | 
					        refreshPreKeysIfNecessary(ServiceIdType.PNI);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public void forceRefreshPreKeys() throws IOException {
 | 
				
			||||||
 | 
					        forceRefreshPreKeys(ServiceIdType.ACI);
 | 
				
			||||||
 | 
					        forceRefreshPreKeys(ServiceIdType.PNI);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
 | 
					    public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
 | 
				
			||||||
        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
 | 
					        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
 | 
				
			||||||
        if (identityKeyPair == null) {
 | 
					        if (identityKeyPair == null) {
 | 
				
			||||||
@ -56,6 +61,22 @@ public class PreKeyHelper {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public void forceRefreshPreKeys(ServiceIdType serviceIdType) throws IOException {
 | 
				
			||||||
 | 
					        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
 | 
				
			||||||
 | 
					        if (identityKeyPair == null) {
 | 
				
			||||||
 | 
					            return;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        final var accountId = account.getAccountId(serviceIdType);
 | 
				
			||||||
 | 
					        if (accountId == null) {
 | 
				
			||||||
 | 
					            return;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final var counts = new OneTimePreKeyCounts(0, 0);
 | 
				
			||||||
 | 
					        if (refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true)) {
 | 
				
			||||||
 | 
					            refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private boolean refreshPreKeysIfNecessary(
 | 
					    private boolean refreshPreKeysIfNecessary(
 | 
				
			||||||
            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
 | 
					            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
 | 
				
			||||||
    ) throws IOException {
 | 
					    ) throws IOException {
 | 
				
			||||||
@ -67,8 +88,17 @@ public class PreKeyHelper {
 | 
				
			|||||||
            preKeyCounts = new OneTimePreKeyCounts(0, 0);
 | 
					            preKeyCounts = new OneTimePreKeyCounts(0, 0);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, preKeyCounts, false);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private boolean refreshPreKeysIfNecessary(
 | 
				
			||||||
 | 
					            final ServiceIdType serviceIdType,
 | 
				
			||||||
 | 
					            final IdentityKeyPair identityKeyPair,
 | 
				
			||||||
 | 
					            final OneTimePreKeyCounts preKeyCounts,
 | 
				
			||||||
 | 
					            final boolean force
 | 
				
			||||||
 | 
					    ) throws IOException {
 | 
				
			||||||
        List<PreKeyRecord> preKeyRecords = null;
 | 
					        List<PreKeyRecord> preKeyRecords = null;
 | 
				
			||||||
        if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
 | 
					        if (force || preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
 | 
				
			||||||
            logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
 | 
					            logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
 | 
				
			||||||
                    serviceIdType,
 | 
					                    serviceIdType,
 | 
				
			||||||
                    preKeyCounts.getEcCount(),
 | 
					                    preKeyCounts.getEcCount(),
 | 
				
			||||||
@ -77,13 +107,13 @@ public class PreKeyHelper {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SignedPreKeyRecord signedPreKeyRecord = null;
 | 
					        SignedPreKeyRecord signedPreKeyRecord = null;
 | 
				
			||||||
        if (signedPreKeyNeedsRefresh(serviceIdType)) {
 | 
					        if (force || signedPreKeyNeedsRefresh(serviceIdType)) {
 | 
				
			||||||
            logger.debug("Refreshing {} signed pre key.", serviceIdType);
 | 
					            logger.debug("Refreshing {} signed pre key.", serviceIdType);
 | 
				
			||||||
            signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
 | 
					            signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        List<KyberPreKeyRecord> kyberPreKeyRecords = null;
 | 
					        List<KyberPreKeyRecord> kyberPreKeyRecords = null;
 | 
				
			||||||
        if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
 | 
					        if (force || preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
 | 
				
			||||||
            logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
 | 
					            logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
 | 
				
			||||||
                    serviceIdType,
 | 
					                    serviceIdType,
 | 
				
			||||||
                    preKeyCounts.getKyberCount(),
 | 
					                    preKeyCounts.getKyberCount(),
 | 
				
			||||||
@ -92,9 +122,11 @@ public class PreKeyHelper {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
 | 
					        KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
 | 
				
			||||||
        if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
 | 
					        if (force || lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
 | 
				
			||||||
            logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
 | 
					            logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
 | 
				
			||||||
            lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
 | 
					            lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType,
 | 
				
			||||||
 | 
					                    identityKeyPair,
 | 
				
			||||||
 | 
					                    kyberPreKeyRecords == null ? 0 : kyberPreKeyRecords.size());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (signedPreKeyRecord == null
 | 
					        if (signedPreKeyRecord == null
 | 
				
			||||||
@ -157,9 +189,7 @@ public class PreKeyHelper {
 | 
				
			|||||||
        final var accountData = account.getAccountData(serviceIdType);
 | 
					        final var accountData = account.getAccountData(serviceIdType);
 | 
				
			||||||
        final var offset = accountData.getPreKeyMetadata().getNextPreKeyId();
 | 
					        final var offset = accountData.getPreKeyMetadata().getNextPreKeyId();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var records = KeyUtils.generatePreKeyRecords(offset);
 | 
					        return KeyUtils.generatePreKeyRecords(offset);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        return records;
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
 | 
					    private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
 | 
				
			||||||
@ -210,10 +240,10 @@ public class PreKeyHelper {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private KyberPreKeyRecord generateLastResortKyberPreKey(
 | 
					    private KyberPreKeyRecord generateLastResortKyberPreKey(
 | 
				
			||||||
            ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
 | 
					            ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair, final int offset
 | 
				
			||||||
    ) {
 | 
					    ) {
 | 
				
			||||||
        final var accountData = account.getAccountData(serviceIdType);
 | 
					        final var accountData = account.getAccountData(serviceIdType);
 | 
				
			||||||
        final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId();
 | 
					        final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId() + offset;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
 | 
					        return KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
@ -114,7 +114,7 @@ public class SignalAccount implements Closeable {
 | 
				
			|||||||
    private static final Logger logger = LoggerFactory.getLogger(SignalAccount.class);
 | 
					    private static final Logger logger = LoggerFactory.getLogger(SignalAccount.class);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private static final int MINIMUM_STORAGE_VERSION = 1;
 | 
					    private static final int MINIMUM_STORAGE_VERSION = 1;
 | 
				
			||||||
    private static final int CURRENT_STORAGE_VERSION = 8;
 | 
					    private static final int CURRENT_STORAGE_VERSION = 9;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private final Object LOCK = new Object();
 | 
					    private final Object LOCK = new Object();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1111,7 +1111,7 @@ public class SignalAccount implements Closeable {
 | 
				
			|||||||
                serviceIdType,
 | 
					                serviceIdType,
 | 
				
			||||||
                preKeyMetadata.nextKyberPreKeyId);
 | 
					                preKeyMetadata.nextKyberPreKeyId);
 | 
				
			||||||
        accountData.getSignalServiceAccountDataStore()
 | 
					        accountData.getSignalServiceAccountDataStore()
 | 
				
			||||||
                .markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis());
 | 
					                .markAllOneTimeKyberPreKeysStaleIfNecessary(System.currentTimeMillis());
 | 
				
			||||||
        for (var record : records) {
 | 
					        for (var record : records) {
 | 
				
			||||||
            if (preKeyMetadata.nextKyberPreKeyId != record.getId()) {
 | 
					            if (preKeyMetadata.nextKyberPreKeyId != record.getId()) {
 | 
				
			||||||
                logger.error("Invalid kyber pre key id {}, expected {}",
 | 
					                logger.error("Invalid kyber pre key id {}, expected {}",
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user