Merge pull request #1707 from famedly/krille/implement-refresh-access-token

feat: Implement handling soft logout
This commit is contained in:
Krille-chan 2024-02-23 13:13:50 +01:00 committed by GitHub
commit 72b5a5e8e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 239 additions and 2 deletions

View File

@ -89,6 +89,10 @@ class Client extends MatrixApi {
bool shareKeysWithUnverifiedDevices;
Future<void> Function(Client client)? onSoftLogout;
DateTime? accessTokenExpiresAt;
// For CommandsClientExtension
final Map<String, FutureOr<String?> Function(CommandArgs)> commands = {};
final Filter syncFilter;
@ -184,6 +188,13 @@ class Client extends MatrixApi {
this.shareKeysWithUnverifiedDevices = true,
this.enableDehydratedDevices = false,
this.receiptsPublicByDefault = true,
/// Implement your https://spec.matrix.org/v1.9/client-server-api/#soft-logout
/// logic here.
/// Set this to `refreshAccessToken()` for the easiest way to handle the
/// most common reason for soft logouts.
/// You can also perform a new login here by passing the existing deviceId.
this.onSoftLogout,
}) : syncFilter = syncFilter ??
Filter(
room: RoomFilter(
@ -234,6 +245,41 @@ class Client extends MatrixApi {
registerDefaultCommands();
}
/// Fetches the refreshToken from the database and tries to get a new
/// access token from the server and then stores it correctly. Unlike the
/// pure API call of `Client.refresh()` this handles the complete soft
/// logout case.
/// Throws an Exception if there is no refresh token available or the
/// client is not logged in.
Future<void> refreshAccessToken() async {
final storedClient = await database?.getClient(clientName);
final refreshToken = storedClient?.tryGet<String>('refresh_token');
if (refreshToken == null) {
throw Exception('No refresh token available');
}
final homeserverUrl = homeserver?.toString();
final userId = userID;
final deviceId = deviceID;
if (homeserverUrl == null || userId == null || deviceId == null) {
throw Exception('Cannot refresh access token when not logged in');
}
final tokenResponse = await refresh(refreshToken);
accessToken = tokenResponse.accessToken;
await database?.updateClient(
homeserverUrl,
tokenResponse.accessToken,
accessTokenExpiresAt,
tokenResponse.refreshToken,
userId,
deviceId,
deviceName,
prevBatch,
encryption?.pickledOlmAccount,
);
}
/// The required name for this client.
final String clientName;
@ -485,6 +531,7 @@ class Client extends MatrixApi {
deviceId: deviceId,
initialDeviceDisplayName: initialDeviceDisplayName,
inhibitLogin: inhibitLogin,
refreshToken: refreshToken ?? onSoftLogout != null,
);
// Connect if there is an access token in the response.
@ -496,8 +543,15 @@ class Client extends MatrixApi {
throw Exception(
'Registered but token, device ID, user ID or homeserver is null.');
}
final expiresInMs = response.expiresInMs;
final tokenExpiresAt = expiresInMs == null
? null
: DateTime.now().add(Duration(milliseconds: expiresInMs));
await init(
newToken: accessToken,
newTokenExpiresAt: tokenExpiresAt,
newRefreshToken: response.refreshToken,
newUserID: userId,
newHomeserver: homeserver,
newDeviceName: initialDeviceDisplayName ?? '',
@ -548,6 +602,7 @@ class Client extends MatrixApi {
medium: medium,
// ignore: deprecated_member_use
address: address,
refreshToken: refreshToken ?? onSoftLogout != null,
);
// Connect if there is an access token in the response.
@ -558,8 +613,16 @@ class Client extends MatrixApi {
if (homeserver_ == null) {
throw Exception('Registered but homerserver is null.');
}
final expiresInMs = response.expiresInMs;
final tokenExpiresAt = expiresInMs == null
? null
: DateTime.now().add(Duration(milliseconds: expiresInMs));
await init(
newToken: accessToken,
newTokenExpiresAt: tokenExpiresAt,
newRefreshToken: response.refreshToken,
newUserID: userId,
newHomeserver: homeserver_,
newDeviceName: initialDeviceDisplayName ?? '',
@ -1474,6 +1537,8 @@ class Client extends MatrixApi {
/// `userDeviceKeysLoading` where it is necessary.
Future<void> init({
String? newToken,
DateTime? newTokenExpiresAt,
String? newRefreshToken,
Uri? newHomeserver,
String? newUserID,
String? newDeviceName,
@ -1531,6 +1596,11 @@ class Client extends MatrixApi {
_id = account['client_id'];
homeserver = Uri.parse(account['homeserver_url']);
accessToken = this.accessToken = account['token'];
final tokenExpiresAtMs =
int.tryParse(account.tryGet<String>('token_expires_at') ?? '');
accessTokenExpiresAt = tokenExpiresAtMs == null
? null
: DateTime.fromMillisecondsSinceEpoch(tokenExpiresAtMs);
userID = _userID = account['user_id'];
_deviceID = account['device_id'];
_deviceName = account['device_name'];
@ -1540,6 +1610,7 @@ class Client extends MatrixApi {
}
if (newToken != null) {
accessToken = this.accessToken = newToken;
accessTokenExpiresAt = newTokenExpiresAt;
homeserver = newHomeserver;
userID = _userID = newUserID;
_deviceID = newDeviceID;
@ -1547,6 +1618,7 @@ class Client extends MatrixApi {
olmAccount = newOlmAccount;
} else {
accessToken = this.accessToken = newToken ?? accessToken;
accessTokenExpiresAt = newTokenExpiresAt ?? accessTokenExpiresAt;
homeserver = newHomeserver ?? homeserver;
userID = _userID = newUserID ?? userID;
_deviceID = newDeviceID ?? _deviceID;
@ -1587,6 +1659,8 @@ class Client extends MatrixApi {
await database.updateClient(
homeserver.toString(),
accessToken,
accessTokenExpiresAt,
newRefreshToken,
userID,
_deviceID,
_deviceName,
@ -1598,6 +1672,8 @@ class Client extends MatrixApi {
clientName,
homeserver.toString(),
accessToken,
accessTokenExpiresAt,
newRefreshToken,
userID,
_deviceID,
_deviceName,
@ -1744,6 +1820,15 @@ class Client extends MatrixApi {
Object? syncError;
await _checkSyncFilter();
// Call onSoftLogout 5 minutes before access token expires to prevent
// failing network requests.
final tokenExpiresAt = accessTokenExpiresAt;
if (onSoftLogout != null &&
tokenExpiresAt != null &&
tokenExpiresAt.difference(DateTime.now()) <= Duration(minutes: 5)) {
await onSoftLogout?.call(this);
}
// The timeout we send to the server for the sync loop. It says to the
// server that we want to receive an empty sync response after this
// amount of time if nothing happens.
@ -1822,8 +1907,19 @@ class Client extends MatrixApi {
onSyncStatus.add(SyncStatusUpdate(SyncStatus.error,
error: SdkError(exception: e, stackTrace: s)));
if (e.error == MatrixError.M_UNKNOWN_TOKEN) {
Logs().w('The user has been logged out!');
await clear();
final onSoftLogout = this.onSoftLogout;
if (e.raw.tryGet<bool>('soft_logout') == true && onSoftLogout != null) {
Logs().w('The user has been soft logged out! Try to login again...');
try {
await onSoftLogout(this);
} catch (e, s) {
Logs().e('Unable to login again', e, s);
await clear();
}
} else {
Logs().w('The user has been logged out!');
await clear();
}
}
} on SyncConnectionException catch (e, s) {
Logs().w('Syncloop failed: Client has not connection to the server');
@ -3108,10 +3204,16 @@ class Client extends MatrixApi {
Logs().i('Found data in the legacy database!');
onMigration?.call();
_id = migrateClient['client_id'];
final tokenExpiresAtMs =
int.tryParse(migrateClient.tryGet<String>('token_expires_at') ?? '');
await database.insertClient(
clientName,
migrateClient['homeserver_url'],
migrateClient['token'],
tokenExpiresAtMs == null
? null
: DateTime.fromMillisecondsSinceEpoch(tokenExpiresAtMs),
migrateClient['refresh_token'],
migrateClient['user_id'],
migrateClient['device_id'],
migrateClient['device_name'],

View File

@ -33,6 +33,8 @@ abstract class DatabaseApi {
Future updateClient(
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -44,6 +46,8 @@ abstract class DatabaseApi {
String name,
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,

View File

@ -785,6 +785,8 @@ class HiveCollectionsDatabase extends DatabaseApi {
String name,
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -794,6 +796,19 @@ class HiveCollectionsDatabase extends DatabaseApi {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
await _clientBox.put('user_id', userId);
if (refreshToken == null) {
await _clientBox.delete('refresh_token');
} else {
await _clientBox.put('refresh_token', refreshToken);
}
if (tokenExpiresAt == null) {
await _clientBox.delete('token_expires_at');
} else {
await _clientBox.put(
'token_expires_at',
tokenExpiresAt.millisecondsSinceEpoch.toString(),
);
}
if (deviceId == null) {
await _clientBox.delete('device_id');
} else {
@ -1371,6 +1386,8 @@ class HiveCollectionsDatabase extends DatabaseApi {
Future<void> updateClient(
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -1380,6 +1397,17 @@ class HiveCollectionsDatabase extends DatabaseApi {
await transaction(() async {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
if (tokenExpiresAt == null) {
await _clientBox.delete('token_expires_at');
} else {
await _clientBox.put('token_expires_at',
tokenExpiresAt.millisecondsSinceEpoch.toString());
}
if (refreshToken == null) {
await _clientBox.delete('refresh_token');
} else {
await _clientBox.put('refresh_token', refreshToken);
}
await _clientBox.put('user_id', userId);
if (deviceId == null) {
await _clientBox.delete('device_id');

View File

@ -750,6 +750,8 @@ class FamedlySdkHiveDatabase extends DatabaseApi with ZoneTransactionMixin {
String name,
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -757,6 +759,9 @@ class FamedlySdkHiveDatabase extends DatabaseApi with ZoneTransactionMixin {
String? olmAccount) async {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
await _clientBox.put(
'token_expires_at', tokenExpiresAt?.millisecondsSinceEpoch.toString());
await _clientBox.put('refresh_token', refreshToken);
await _clientBox.put('user_id', userId);
await _clientBox.put('device_id', deviceId);
await _clientBox.put('device_name', deviceName);
@ -1314,6 +1319,8 @@ class FamedlySdkHiveDatabase extends DatabaseApi with ZoneTransactionMixin {
Future<void> updateClient(
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -1322,6 +1329,9 @@ class FamedlySdkHiveDatabase extends DatabaseApi with ZoneTransactionMixin {
) async {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
await _clientBox.put(
'token_expires_at', tokenExpiresAt?.millisecondsSinceEpoch.toString());
await _clientBox.put('refresh_token', refreshToken);
await _clientBox.put('user_id', userId);
await _clientBox.put('device_id', deviceId);
await _clientBox.put('device_name', deviceName);

View File

@ -727,6 +727,8 @@ class MatrixSdkDatabase extends DatabaseApi {
String name,
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -735,6 +737,17 @@ class MatrixSdkDatabase extends DatabaseApi {
await transaction(() async {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
if (tokenExpiresAt == null) {
await _clientBox.delete('token_expires_at');
} else {
await _clientBox.put('token_expires_at',
tokenExpiresAt.millisecondsSinceEpoch.toString());
}
if (refreshToken == null) {
await _clientBox.delete('refresh_token');
} else {
await _clientBox.put('refresh_token', refreshToken);
}
await _clientBox.put('user_id', userId);
if (deviceId == null) {
await _clientBox.delete('device_id');
@ -1343,6 +1356,8 @@ class MatrixSdkDatabase extends DatabaseApi {
Future<void> updateClient(
String homeserverUrl,
String token,
DateTime? tokenExpiresAt,
String? refreshToken,
String userId,
String? deviceId,
String? deviceName,
@ -1352,6 +1367,17 @@ class MatrixSdkDatabase extends DatabaseApi {
await transaction(() async {
await _clientBox.put('homeserver_url', homeserverUrl);
await _clientBox.put('token', token);
if (tokenExpiresAt == null) {
await _clientBox.delete('token_expires_at');
} else {
await _clientBox.put('token_expires_at',
tokenExpiresAt.millisecondsSinceEpoch.toString());
}
if (refreshToken == null) {
await _clientBox.delete('refresh_token');
} else {
await _clientBox.put('refresh_token', refreshToken);
}
await _clientBox.put('user_id', userId);
if (deviceId == null) {
await _clientBox.delete('device_id');

View File

@ -964,6 +964,35 @@ void main() {
await client.dispose(closeDatabase: true);
});
test('refreshAccessToken', () async {
final client = await getClient();
expect(client.accessToken, 'abcd');
await client.refreshAccessToken();
expect(client.accessToken, 'a_new_token');
});
test('handleSoftLogout', () async {
final client = await getClient();
expect(client.accessToken, 'abcd');
var softLoggedOut = 0;
client.onSoftLogout = (client) {
softLoggedOut++;
return client.refreshAccessToken();
};
FakeMatrixApi.expectedAccessToken = 'a_new_token';
await client.oneShotSync();
await client.oneShotSync();
FakeMatrixApi.expectedAccessToken = null;
expect(client.accessToken, 'a_new_token');
expect(softLoggedOut, 1);
final storedClient = await client.database?.getClient(client.clientName);
expect(storedClient?.tryGet<String>('token'), 'a_new_token');
expect(
storedClient?.tryGet<String>('refresh_token'),
'another_new_token',
);
});
test('object equality', () async {
final time1 = DateTime.fromMillisecondsSinceEpoch(1);
final time2 = DateTime.fromMillisecondsSinceEpoch(0);

View File

@ -129,10 +129,13 @@ void main() {
await database.getClient('name');
});
test('insertClient', () async {
final now = DateTime.now();
await database.insertClient(
'name',
'homeserverUrl',
'token',
now,
'refresh_token',
'userId',
'deviceId',
'deviceName',
@ -142,11 +145,17 @@ void main() {
final client = await database.getClient('name');
expect(client?['token'], 'token');
expect(
client?['token_expires_at'],
now.millisecondsSinceEpoch.toString(),
);
});
test('updateClient', () async {
await database.updateClient(
'homeserverUrl',
'token_different',
DateTime.now(),
'refresh_token',
'userId',
'deviceId',
'deviceName',

View File

@ -32,12 +32,14 @@ Future<Client> getClient() async {
'testclient',
httpClient: FakeMatrixApi(),
databaseBuilder: getDatabase,
onSoftLogout: (client) => client.refreshAccessToken(),
);
FakeMatrixApi.client = client;
await client.checkHomeserver(Uri.parse('https://fakeServer.notExisting'),
checkWellKnown: false);
await client.init(
newToken: 'abcd',
newRefreshToken: 'refresh_abcd',
newUserID: '@test:fakeServer.notExisting',
newHomeserver: client.homeserver,
newDeviceName: 'Text Matrix Client',

View File

@ -39,6 +39,8 @@ Map<String, dynamic> decodeJson(dynamic data) {
}
class FakeMatrixApi extends BaseClient {
static String? expectedAccessToken;
static Map<String, List<dynamic>> get calledEndpoints =>
currentApi!._calledEndpoints;
static int get eventCounter => currentApi!._eventCounter;
@ -129,6 +131,23 @@ class FakeMatrixApi extends BaseClient {
'<html><head></head><body>Not found...</body></html>', 404);
}
if (!{
'/client/v3/refresh',
'/client/v3/login',
'/client/v3/register',
}.contains(action) &&
expectedAccessToken != null &&
request.headers['Authorization'] != 'Bearer $expectedAccessToken') {
return Response(
jsonEncode({
'errcode': 'M_UNKNOWN_TOKEN',
'error': 'Soft logged out',
'soft_logout': true,
}),
401,
);
}
// Call API
(_calledEndpoints[action] ??= <dynamic>[]).add(data);
final act = api[method]?[action];
@ -2013,6 +2032,11 @@ class FakeMatrixApi extends BaseClient {
},
},
'POST': {
'/client/v3/refresh': (var req) => {
'access_token': 'a_new_token',
'expires_in_ms': 60000,
'refresh_token': 'another_new_token'
},
'/client/v3/delete_devices': (var req) => {},
'/client/v3/account/3pid/add': (var req) => {},
'/client/v3/account/3pid/bind': (var req) => {},
@ -2397,6 +2421,7 @@ class FakeMatrixApi extends BaseClient {
'/client/v3/login': (var req) => {
'user_id': '@test:fakeServer.notExisting',
'access_token': 'abc123',
'refresh_token': 'refresh_abc123',
'device_id': 'GHTYAJCE',
'well_known': {
'm.homeserver': {'base_url': 'https://example.org'},

View File

@ -32,6 +32,8 @@ void main() {
'testclient',
'https://example.org',
'blubb',
null,
null,
'@test:example.org',
null,
null,