Skip to content

Commit 4fe85d2

Browse files
committed
Refresh OAuth tokens
1 parent 196d2d7 commit 4fe85d2

File tree

5 files changed

+220
-10
lines changed

5 files changed

+220
-10
lines changed

src/matrix/Client.js

+35-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {AbortableOperation} from "../utils/AbortableOperation";
2121
import {ObservableValue} from "../observable/ObservableValue";
2222
import {HomeServerApi} from "./net/HomeServerApi";
2323
import {OidcApi} from "./net/OidcApi";
24+
import {TokenRefresher} from "./net/TokenRefresher";
2425
import {Reconnector, ConnectionStatus} from "./net/Reconnector";
2526
import {ExponentialRetryDelay} from "./net/ExponentialRetryDelay";
2627
import {MediaRepository} from "./net/MediaRepository";
@@ -182,7 +183,7 @@ export class Client {
182183
}
183184

184185
if (loginData.expires_in) {
185-
sessionInfo.expiresAt = clock.now() + loginData.expires_in * 1000;
186+
sessionInfo.accessTokenExpiresAt = clock.now() + loginData.expires_in * 1000;
186187
}
187188

188189
if (loginData.oidc_issuer) {
@@ -242,9 +243,41 @@ export class Client {
242243
retryDelay: new ExponentialRetryDelay(clock.createTimeout),
243244
createMeasure: clock.createMeasure
244245
});
246+
247+
let accessToken;
248+
249+
if (sessionInfo.oidcIssuer) {
250+
const oidcApi = new OidcApi({
251+
issuer: sessionInfo.oidcIssuer,
252+
clientId: "hydrogen-web",
253+
request: this._platform.request,
254+
encoding: this._platform.encoding,
255+
});
256+
257+
// TODO: stop/pause the refresher?
258+
const tokenRefresher = new TokenRefresher({
259+
oidcApi,
260+
clock: this._platform.clock,
261+
accessToken: sessionInfo.accessToken,
262+
accessTokenExpiresAt: sessionInfo.accessTokenExpiresAt,
263+
refreshToken: sessionInfo.refreshToken,
264+
anticipation: 30 * 1000,
265+
});
266+
267+
tokenRefresher.token.subscribe(t => {
268+
this._platform.sessionInfoStorage.updateToken(sessionInfo.id, t.accessToken, t.accessTokenExpiresAt, t.refreshToken);
269+
});
270+
271+
await tokenRefresher.start();
272+
273+
accessToken = tokenRefresher.accessToken;
274+
} else {
275+
accessToken = new ObservableValue(sessionInfo.accessToken);
276+
}
277+
245278
const hsApi = new HomeServerApi({
246279
homeserver: sessionInfo.homeServer,
247-
accessToken: sessionInfo.accessToken,
280+
accessToken,
248281
request: this._platform.request,
249282
reconnector: this._reconnector,
250283
});

src/matrix/net/HomeServerApi.ts

+13-7
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,18 @@ const DEHYDRATION_PREFIX = "/_matrix/client/unstable/org.matrix.msc2697.v2";
3030

3131
type Options = {
3232
homeserver: string;
33-
accessToken: string;
33+
accessToken: BaseObservableValue<string>;
3434
request: RequestFunction;
3535
reconnector: Reconnector;
3636
};
3737

3838
export class HomeServerApi {
3939
private readonly _homeserver: string;
40-
private readonly _accessToken: string;
40+
private readonly _accessToken: BaseObservableValue<string>;
4141
private readonly _requestFn: RequestFunction;
4242
private readonly _reconnector: Reconnector;
4343

44-
constructor({homeserver, accessToken, request, reconnector}: Options) {
44+
constructor({ homeserver, accessToken, request, reconnector }: Options) {
4545
// store these both in a closure somehow so it's harder to get at in case of XSS?
4646
// one could change the homeserver as well so the token gets sent there, so both must be protected from read/write
4747
this._homeserver = homeserver;
@@ -54,7 +54,7 @@ export class HomeServerApi {
5454
return this._homeserver + prefix + csPath;
5555
}
5656

57-
private _baseRequest(method: RequestMethod, url: string, queryParams?: Record<string, any>, body?: Record<string, any>, options?: IRequestOptions, accessToken?: string): IHomeServerRequest {
57+
private _baseRequest(method: RequestMethod, url: string, queryParams?: Record<string, any>, body?: Record<string, any>, options?: IRequestOptions, accessTokenSource?: BaseObservableValue<string>): IHomeServerRequest {
5858
const queryString = encodeQueryParams(queryParams);
5959
url = `${url}?${queryString}`;
6060
let log: ILogItem | undefined;
@@ -68,9 +68,14 @@ export class HomeServerApi {
6868
}
6969
let encodedBody: EncodedBody["body"];
7070
const headers: Map<string, string | number> = new Map();
71+
72+
let accessToken: string | null = null;
7173
if (options?.accessTokenOverride) {
7274
accessToken = options.accessTokenOverride;
75+
} else if (accessTokenSource) {
76+
accessToken = accessTokenSource.get();
7377
}
78+
7479
if (accessToken) {
7580
headers.set("Authorization", `Bearer ${accessToken}`);
7681
}
@@ -91,7 +96,7 @@ export class HomeServerApi {
9196
});
9297

9398
const hsRequest = new HomeServerRequest(method, url, requestResult, log);
94-
99+
95100
if (this._reconnector) {
96101
hsRequest.response().catch(err => {
97102
// Some endpoints such as /sync legitimately time-out
@@ -282,11 +287,12 @@ export class HomeServerApi {
282287

283288
claimDehydratedDevice(deviceId: string, options: IRequestOptions): IHomeServerRequest {
284289
options.prefix = DEHYDRATION_PREFIX;
285-
return this._post(`/dehydrated_device/claim`, {}, {device_id: deviceId}, options);
290+
return this._post(`/dehydrated_device/claim`, {}, { device_id: deviceId }, options);
286291
}
287292
}
288293

289-
import {Request as MockRequest} from "../../mocks/Request.js";
294+
import { Request as MockRequest } from "../../mocks/Request.js";
295+
import { BaseObservableValue } from "../../observable/ObservableValue";
290296

291297
export function tests() {
292298
return {

src/matrix/net/TokenRefresher.ts

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
Copyright 2022 The Matrix.org Foundation C.I.C.
3+
4+
Licensed under the Apache License, Version 2.0 (the 'License');
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an 'AS IS' BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
import { BaseObservableValue, ObservableValue } from "../../observable/ObservableValue";
18+
import type { Clock, Timeout } from "../../platform/web/dom/Clock";
19+
import { OidcApi } from "./OidcApi";
20+
21+
type Token = {
22+
accessToken: string,
23+
accessTokenExpiresAt: number,
24+
refreshToken: string,
25+
};
26+
27+
28+
export class TokenRefresher {
29+
private _token: ObservableValue<Token>;
30+
private _accessToken: BaseObservableValue<string>;
31+
private _anticipation: number;
32+
private _clock: Clock;
33+
private _oidcApi: OidcApi;
34+
private _timeout: Timeout
35+
36+
constructor({
37+
oidcApi,
38+
refreshToken,
39+
accessToken,
40+
accessTokenExpiresAt,
41+
anticipation,
42+
clock,
43+
}: {
44+
oidcApi: OidcApi,
45+
refreshToken: string,
46+
accessToken: string,
47+
accessTokenExpiresAt: number,
48+
anticipation: number,
49+
clock: Clock,
50+
}) {
51+
this._token = new ObservableValue({
52+
accessToken,
53+
accessTokenExpiresAt,
54+
refreshToken,
55+
});
56+
this._accessToken = this._token.map(t => t.accessToken);
57+
58+
this._anticipation = anticipation;
59+
this._oidcApi = oidcApi;
60+
this._clock = clock;
61+
}
62+
63+
async start() {
64+
if (this.needsRenewing) {
65+
await this.renew();
66+
}
67+
68+
this._renewingLoop();
69+
}
70+
71+
stop() {
72+
// TODO
73+
}
74+
75+
get needsRenewing() {
76+
const remaining = this._token.get().accessTokenExpiresAt - this._clock.now();
77+
const anticipated = remaining - this._anticipation;
78+
return anticipated < 0;
79+
}
80+
81+
async _renewingLoop() {
82+
while (true) {
83+
const remaining =
84+
this._token.get().accessTokenExpiresAt - this._clock.now();
85+
const anticipated = remaining - this._anticipation;
86+
87+
if (anticipated > 0) {
88+
this._timeout = this._clock.createTimeout(anticipated);
89+
await this._timeout.elapsed();
90+
}
91+
92+
await this.renew();
93+
}
94+
}
95+
96+
async renew() {
97+
let refreshToken = this._token.get().refreshToken;
98+
const response = await this._oidcApi
99+
.refreshToken({
100+
refreshToken,
101+
});
102+
103+
if (typeof response.expires_in !== "number") {
104+
throw new Error("Refreshed access token does not expire");
105+
}
106+
107+
if (response.refresh_token) {
108+
refreshToken = response.refresh_token;
109+
}
110+
111+
this._token.set({
112+
refreshToken,
113+
accessToken: response.access_token,
114+
accessTokenExpiresAt: this._clock.now() + response.expires_in * 1000,
115+
});
116+
}
117+
118+
get accessToken(): BaseObservableValue<string> {
119+
return this._accessToken;
120+
}
121+
122+
get token(): BaseObservableValue<Token> {
123+
return this._token;
124+
}
125+
}

src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts

+15-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ interface ISessionInfo {
2121
homeserver: string;
2222
homeServer: string; // deprecate this over time
2323
accessToken: string;
24+
accessTokenExpiresAt?: number;
2425
refreshToken?: string;
25-
expiresAt?: number;
2626
oidcIssuer?: string;
2727
lastUsed: number;
2828
}
@@ -31,6 +31,7 @@ interface ISessionInfo {
3131
interface ISessionInfoStorage {
3232
getAll(): Promise<ISessionInfo[]>;
3333
updateLastUsed(id: string, timestamp: number): Promise<void>;
34+
updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise<void>;
3435
get(id: string): Promise<ISessionInfo | undefined>;
3536
add(sessionInfo: ISessionInfo): Promise<void>;
3637
delete(sessionId: string): Promise<void>;
@@ -65,6 +66,19 @@ export class SessionInfoStorage implements ISessionInfoStorage {
6566
}
6667
}
6768

69+
async updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise<void> {
70+
const sessions = await this.getAll();
71+
if (sessions) {
72+
const session = sessions.find(session => session.id === id);
73+
if (session) {
74+
session.accessToken = accessToken;
75+
session.accessTokenExpiresAt = accessTokenExpiresAt;
76+
session.refreshToken = refreshToken;
77+
localStorage.setItem(this._name, JSON.stringify(sessions));
78+
}
79+
}
80+
}
81+
6882
async get(id: string): Promise<ISessionInfo | undefined> {
6983
const sessions = await this.getAll();
7084
if (sessions) {

src/observable/ObservableValue.ts

+32
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ export abstract class BaseObservableValue<T> extends BaseObservable<(value: T) =
3939
flatMap<C>(mapper: (value: T) => (BaseObservableValue<C> | undefined)): BaseObservableValue<C | undefined> {
4040
return new FlatMapObservableValue<T, C>(this, mapper);
4141
}
42+
43+
map<C>(mapper: (value: T) => C): BaseObservableValue<C> {
44+
return new MappedObservableValue<T, C>(this, mapper);
45+
}
4246
}
4347

4448
interface IWaitHandle<T> {
@@ -174,6 +178,34 @@ export class FlatMapObservableValue<P, C> extends BaseObservableValue<C | undefi
174178
}
175179
}
176180

181+
export class MappedObservableValue<P, C> extends BaseObservableValue<C> {
182+
private sourceSubscription?: SubscriptionHandle;
183+
184+
constructor(
185+
private readonly source: BaseObservableValue<P>,
186+
private readonly mapper: (value: P) => C
187+
) {
188+
super();
189+
}
190+
191+
onUnsubscribeLast() {
192+
super.onUnsubscribeLast();
193+
this.sourceSubscription = this.sourceSubscription!();
194+
}
195+
196+
onSubscribeFirst() {
197+
super.onSubscribeFirst();
198+
this.sourceSubscription = this.source.subscribe(() => {
199+
this.emit(this.get());
200+
});
201+
}
202+
203+
get(): C {
204+
const sourceValue = this.source.get();
205+
return this.mapper(sourceValue);
206+
}
207+
}
208+
177209
export function tests() {
178210
return {
179211
"set emits an update": assert => {

0 commit comments

Comments
 (0)