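We upload some large files over a LAN with OkHttp POST requests. Testing showed that the uploads often saturated the bandwidth and disrupted normal business traffic, so uploads need to be rate-limited.
Rate limiting can be done on the server or on the client. Limiting on the server only delays reading the data, which leaves it backed up in the TCP buffers, so the bandwidth problem is not really solved. On the client, the approach is to throttle writes to the socket. A quick search showed that OkHttp does not provide an API for rate limiting.
So I studied OkHttp's interceptor mechanism (Interceptor).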
1. Executing the request: RealCall
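In RealCall.java, the execute() method of the inner class AsyncCall (the path taken by enqueue()) calls getResponseWithInterceptorChain() to obtain the network Response; the synchronous RealCall.execute() calls it as well.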
RealCall.java
// AsyncCall.execute(); AsyncCall is an inner class of RealCall (OkHttp 3.x)
@Override protected void execute() {
    boolean signalledCallback = false;
    try {
        // Runs the whole interceptor chain and returns the final Response
        Response response = getResponseWithInterceptorChain();
        if (retryAndFollowUpInterceptor.isCanceled()) {
            signalledCallback = true;
            responseCallback.onFailure(RealCall.this, new IOException("Canceled"));
        } else {
            signalledCallback = true;
            responseCallback.onResponse(RealCall.this, response);
        }
    } catch (IOException e) {
        if (signalledCallback) {
            // Do not signal the callback twice!
            Platform.get().log(INFO, "Callback failure for " + toLoggableString(), e);
        } else {
            responseCallback.onFailure(RealCall.this, e);
        }
    } finally {
        client.dispatcher().finished(this);
    }
}
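2. How OkHttp's interceptors are implemented
As the code below shows, the interceptors are added to a list in this order:
- client.interceptors(): custom (application) interceptors
- retryAndFollowUpInterceptor
- BridgeInterceptor
- CacheInterceptor
- ConnectInterceptor
- client.networkInterceptors(): custom network interceptors (skipped for WebSocket calls)
- CallServerInterceptor
The order matters, because the interceptors are executed in list order, as described below.
The actual network I/O happens in CallServerInterceptor, the last one in the list.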
RealCall.java
Response getResponseWithInterceptorChain() throws IOException {
    // Build a full stack of interceptors.
    List<Interceptor> interceptors = new ArrayList<>();
    interceptors.addAll(client.interceptors());
    interceptors.add(retryAndFollowUpInterceptor);
    interceptors.add(new BridgeInterceptor(client.cookieJar()));
    interceptors.add(new CacheInterceptor(client.internalCache()));
    interceptors.add(new ConnectInterceptor(client));
    if (!forWebSocket) {
        interceptors.addAll(client.networkInterceptors());
    }
    interceptors.add(new CallServerInterceptor(forWebSocket));
    Interceptor.Chain chain = new RealInterceptorChain(
            interceptors, null, null, null, 0, originalRequest);
    return chain.proceed(originalRequest);
}
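To make the ordering concrete, here is a minimal sketch (not OkHttp code) that assembles the same list using interceptor names only. `InterceptorOrder` and `build` are hypothetical names chosen for illustration; the sketch only shows where application and network interceptors land and that CallServerInterceptor always comes last.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper mirroring the order used by getResponseWithInterceptorChain();
// interceptors are represented by their names only.
final class InterceptorOrder {
    static List<String> build(List<String> appInterceptors,
                              List<String> networkInterceptors,
                              boolean forWebSocket) {
        List<String> order = new ArrayList<>(appInterceptors); // client.interceptors()
        order.add("RetryAndFollowUpInterceptor");
        order.add("BridgeInterceptor");
        order.add("CacheInterceptor");
        order.add("ConnectInterceptor");
        if (!forWebSocket) {
            order.addAll(networkInterceptors);                 // client.networkInterceptors()
        }
        order.add("CallServerInterceptor");                    // always last: does the socket I/O
        return order;
    }
}
```

With forWebSocket set to true the network interceptors are left out, just as in the `if (!forWebSocket)` branch above.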
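3. Executing RealInterceptorChain
RealInterceptorChain executes as a chain. Note that RealCall.java constructs the first RealInterceptorChain with index 0; proceed() then creates a new RealInterceptorChain, next, whose index is index + 1.
- interceptor.intercept(next): note that the new chain (not a new interceptor) is passed as the argument, so the interceptor can continue down the list by calling next.proceed().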
@Override public Response proceed(Request request) throws IOException {
    return proceed(request, streamAllocation, httpCodec, connection);
}
// Simplified: the real method also validates the index and how often proceed() is called.
public Response proceed(Request request, StreamAllocation streamAllocation, HttpCodec httpCodec,
        RealConnection connection) throws IOException {
    // Call the next interceptor in the chain.
    RealInterceptorChain next = new RealInterceptorChain(
            interceptors, streamAllocation, httpCodec, connection, index + 1, request);
    Interceptor interceptor = interceptors.get(index);
    Response response = interceptor.intercept(next);
    return response;
}
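Inside an interceptor, depending on what it needs to do:
- to process the Request, run the interceptor's own logic first, then call proceed() on the next chain;
- to process the Response, call proceed() on the next chain first, then run the interceptor's logic on the Response it returns.
This is a neat design (essentially the Chain of Responsibility pattern).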
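The index-based chaining and the before/after pattern can be sketched without OkHttp. The following is a simplified, hypothetical model: `MiniInterceptor` and `MiniChain` are illustrative names, requests and responses are plain strings, and there is no network. It only mirrors how RealInterceptorChain.proceed() creates the next chain with index + 1 and hands it to the current interceptor.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of OkHttp's Interceptor / RealInterceptorChain.
interface MiniInterceptor {
    String intercept(MiniChain chain);
}

final class MiniChain {
    private final List<MiniInterceptor> interceptors;
    private final int index;
    private final String request;

    MiniChain(List<MiniInterceptor> interceptors, int index, String request) {
        this.interceptors = interceptors;
        this.index = index;
        this.request = request;
    }

    String request() { return request; }

    // Like RealInterceptorChain.proceed(): build the chain for index + 1,
    // then hand it to the interceptor at the current index.
    String proceed(String request) {
        MiniChain next = new MiniChain(interceptors, index + 1, request);
        return interceptors.get(index).intercept(next);
    }

    // Demo: the first interceptor edits the request before proceed() and the response after it;
    // the last one plays the role of CallServerInterceptor and does not call proceed().
    static String demo(List<String> trace) {
        List<MiniInterceptor> list = new ArrayList<>();
        list.add(chain -> {
            trace.add("before: " + chain.request());
            String response = chain.proceed(chain.request() + "+header");
            trace.add("after: " + response);
            return response + "+logged";
        });
        list.add(chain -> {
            trace.add("server got: " + chain.request());
            return "200 OK";
        });
        return new MiniChain(list, 0, "POST /upload").proceed("POST /upload");
    }
}
```

Calling `MiniChain.demo(trace)` returns `"200 OK+logged"` and records, in order, the request seen by the first interceptor, the modified request seen by the last one, and the response seen on the way back up the chain.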
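4. CallServerInterceptor
CallServerInterceptor is where the network reads and writes happen, so it is the most likely place to implement rate limiting. The request body is written by this line: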
request.body().writeTo(bufferedRequestBody);
That is, the request's body writes the data to the network. Where does this body come from? When uploading data via POST, we build a MultipartBody to wrap the uploaded file.
public CallServerInterceptor(boolean forWebSocket) {
    this.forWebSocket = forWebSocket;
}
@Override public Response intercept(Chain chain) throws IOException {
    // Abridged: request, httpCodec, streamAllocation, connection and sentRequestMillis are
    // obtained from the chain, and the request headers are written, in lines omitted here.
    Response.Builder responseBuilder = null;
    if (HttpMethod.permitsRequestBody(request.method()) && request.body() != null) {
        // (The omitted "Expect: 100-continue" handling may set responseBuilder at this point.)
        if (responseBuilder == null) {
            // Write the request body if the "Expect: 100-continue" expectation was met.
            Sink requestBodyOut = httpCodec.createRequestBody(request, request.body().contentLength());
            BufferedSink bufferedRequestBody = Okio.buffer(requestBodyOut);
            // The upload bytes reach the socket through the body's writeTo()
            request.body().writeTo(bufferedRequestBody);
            bufferedRequestBody.close();
        } else if (!connection.isMultiplexed()) {
            // If the "Expect: 100-continue" expectation wasn't met, prevent the HTTP/1 connection from
            // being reused. Otherwise we're still obligated to transmit the request body to leave the
            // connection in a consistent state.
            streamAllocation.noNewStreams();
        }
    }
    httpCodec.finishRequest();
    if (responseBuilder == null) {
        responseBuilder = httpCodec.readResponseHeaders(false);
    }
    Response response = responseBuilder
            .request(request)
            .handshake(streamAllocation.connection().handshake())
            .sentRequestAtMillis(sentRequestMillis)
            .receivedResponseAtMillis(System.currentTimeMillis())
            .build();
    int code = response.code();
    if (forWebSocket && code == 101) {
        // Connection is upgrading, but we need to ensure interceptors see a non-null response body.
        response = response.newBuilder()
                .body(Util.EMPTY_RESPONSE)
                .build();
    } else {
        response = response.newBuilder()
                .body(httpCodec.openResponseBody(response))
                .build();
    }
    return response;
}
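Since every byte of the upload passes through `request.body().writeTo(...)`, whatever wraps or paces the sink given to writeTo() sees the whole upload. The sketch below illustrates that idea in plain Java, under the assumption that `java.io.OutputStream` stands in for Okio's BufferedSink; `CountingOutputStream` is a hypothetical helper, not part of OkHttp or Okio.

```java
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Hypothetical wrapper: counts every byte written through it to the underlying stream.
final class CountingOutputStream extends FilterOutputStream {
    private long count;

    CountingOutputStream(OutputStream out) { super(out); }

    @Override public void write(int b) throws IOException {
        out.write(b);
        count++;
    }

    // Override the array variant too; FilterOutputStream's default writes one byte at a time.
    @Override public void write(byte[] b, int off, int len) throws IOException {
        out.write(b, off, len);
        count += len;
    }

    long count() { return count; }
}
```

The same wrapping point is where a writer could pause between chunks, which is the idea the following sections build on.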
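5. RequestBody
As the code shows, RequestBody is an abstract class. The file part is created by the static factory RequestBody.create(MediaType, File) (called here as MultipartBody.create, which resolves to the same inherited static method), and that factory returns an anonymous subclass. So our approach is to supply our own RequestBody whose writeTo() controls how fast data is written to the socket.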
public MultipartBody.Part getMultipartBodyPart(){
    RequestBody requestFile = MultipartBody.create(MediaType.parse("multipart/form-data"), new File(mFileEncrypt));
    MultipartBody.Part fileBody = MultipartBody.Part.createFormData(FILE_ENCRYPT, mFileEncrypt, requestFile);
    return fileBody;
}
// okhttp3.RequestBody: the static factory that MultipartBody.create(...) above resolves to
public static RequestBody create(final @Nullable MediaType contentType, final File file) {
    if (file == null) throw new NullPointerException("content == null");
    return new RequestBody() {
        @Override public @Nullable MediaType contentType() {
            return contentType;
        }
        @Override public long contentLength() {
            return file.length();
        }
        @Override public void writeTo(BufferedSink sink) throws IOException {
            Source source = null;
            try {
                source = Okio.source(file);
                sink.writeAll(source);   // copies the whole file as fast as the socket accepts it
            } finally {
                Util.closeQuietly(source);
            }
        }
    };
}
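Controlling the write speed reduces to one rule: after N bytes have been written under a limit of R bit/s, at least N * 8 * 1000 / R milliseconds should have elapsed, and the writer sleeps off any shortfall. A minimal sketch of that rule follows; `Pacing` and `sleepMillis` are hypothetical names, and the rate is assumed to be in bits per second.

```java
// Hypothetical helper: how long to sleep so that bytesSent bytes take at least
// bytesSent * 8 * 1000 / maxRateBitsPerSec milliseconds in total.
final class Pacing {
    static long sleepMillis(long bytesSent, long elapsedMillis, long maxRateBitsPerSec) {
        if (maxRateBitsPerSec <= 0) throw new IllegalArgumentException("rate must be > 0");
        long targetMillis = bytesSent * 8 * 1000 / maxRateBitsPerSec; // integer division
        return Math.max(0, targetMillis - elapsedMillis);
    }
}
```

For example, 1 MiB (1,048,576 bytes) at 8,000,000 bit/s should take 1,048 ms; if 300 ms have passed, `sleepMillis(1048576, 300, 8_000_000)` returns 748, and once 2,000 ms have passed it returns 0.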
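6. RateLimitingRequestBody
The modified code is below. For the Okio part, some code was copied out of Okio (writeAll() is adapted from Okio's own writeAll() loop) so the copy can pause between chunks. Also, because of a build problem, Okio.source is called through reflection.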
public MultipartBody.Part getMultipartBodyPart(){
    // UPLOAD_RATE: upload limit in bit/s, as used by RateLimitingRequestBody.writeAll()
    RequestBody requestFile = RateLimitingRequestBody.createRequestBody(MediaType.parse("multipart/form-data"), new File(mFileEncrypt), UPLOAD_RATE);
    MultipartBody.Part fileBody = MultipartBody.Part.createFormData(FILE_ENCRYPT, mFileEncrypt, requestFile);
    return fileBody;
}
import java.io.File;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import javax.annotation.Nullable; // JSR-305, as used by OkHttp 3.x; use your project's @Nullable if different
import okhttp3.MediaType;
import okhttp3.RequestBody;
import okhttp3.internal.Util;
import okio.BufferedSink;
import okio.Source;
// NLog is a project-specific logging helper, not an OkHttp or Android class.
public class RateLimitingRequestBody extends RequestBody {
    private final MediaType mContentType;
    private final File mFile;
    private final int mMaxRate; // bit/s (writeAll() converts it with * 1000 against milliseconds)
    private RateLimitingRequestBody(@Nullable final MediaType contentType, final File file, int rate) {
        mContentType = contentType;
        mFile = file;
        mMaxRate = rate;
    }
    @Override
    public MediaType contentType() {
        return mContentType;
    }
    // Report the real length like RequestBody.create(MediaType, File); the default is -1 (unknown).
    @Override
    public long contentLength() {
        return mFile.length();
    }
    @Override
    public void writeTo(BufferedSink sink) throws IOException {
        Source source = null;
        try {
            /*
             * reflect instead of Okio.source(mFile) because of build error at platform 23.
             * the error is java.nio.** can't find.
             */
            // source = Okio.source(mFile);
            Method method = Class.forName("okio.Okio").getMethod("source", File.class);
            source = (Source) method.invoke(null, mFile);
            writeAll(sink, source);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            throw new InterruptedIOException("upload interrupted");
        } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException e) {
            // Fail the call instead of silently sending an empty or partial body
            throw new IOException("cannot open " + mFile, e);
        } catch (InvocationTargetException e) {
            // e.g. FileNotFoundException thrown by Okio.source()
            throw new IOException("cannot open " + mFile, e.getCause());
        } finally {
            Util.closeQuietly(source);
        }
    }
    private long writeAll(BufferedSink sink, Source source) throws IOException, InterruptedException {
        if (source == null) {
            throw new IllegalArgumentException("source == null");
        }
        long totalBytesRead = 0L;
        long readCount;
        long start = System.currentTimeMillis();
        while ((readCount = source.read(sink.buffer(), 8192L)) != -1L) {
            totalBytesRead += readCount;
            sink.emitCompleteSegments();
            long elapsed = System.currentTimeMillis() - start;
            // Milliseconds that totalBytesRead bytes should take at mMaxRate bit/s
            long target = totalBytesRead * 8 * 1000 / mMaxRate;
            if (target > elapsed) {
                long sleep = target - elapsed;
                NLog.v("writeAll", "totalBytesRead:" + totalBytesRead + "B sleep:" + sleep + "ms");
                // Sleep only the deficit; a fixed extra delay (e.g. +500 ms) makes the upload bursty
                Thread.sleep(sleep);
            }
        }
        long end = System.currentTimeMillis();
        long rate = totalBytesRead * 8 * 1000 / Math.max(1, end - start); // bit/s; avoids division by zero
        NLog.d("writeAll", "totalBytesRead:" + totalBytesRead + "B Rate:" + rate + "bit/s total time:" + (end - start) + "ms");
        return totalBytesRead;
    }
    public static RequestBody createRequestBody(@Nullable final MediaType contentType, final File file, int rate) {
        if (file == null) {
            throw new NullPointerException("file == null");
        }
        if (rate <= 0) {
            throw new IllegalArgumentException("rate must be > 0 bit/s");
        }
        return new RateLimitingRequestBody(contentType, file, rate);
    }
}
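For comparison, here is the same pacing loop as a self-contained sketch on plain java.io streams rather than Okio (an assumption made so it runs without OkHttp). `ThrottledCopy` and `Sleeper` are hypothetical names; the clock and the sleeper are parameters so the loop can be checked without real waiting. Like writeAll() above, it paces after each 8 KiB chunk; it sleeps only the computed deficit and turns an interrupt into an InterruptedIOException.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.util.function.LongSupplier;

// Hypothetical rate-limited copy: keeps the average rate at or below maxRateBitsPerSec.
final class ThrottledCopy {
    interface Sleeper { void sleep(long millis) throws InterruptedException; }

    static long copy(InputStream in, OutputStream out, long maxRateBitsPerSec,
                     LongSupplier clockMillis, Sleeper sleeper) throws IOException {
        if (maxRateBitsPerSec <= 0) throw new IllegalArgumentException("rate must be > 0");
        byte[] buffer = new byte[8192];
        long total = 0;
        long start = clockMillis.getAsLong();
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
            out.flush();                                          // hand this chunk on before pacing
            total += n;
            long elapsed = clockMillis.getAsLong() - start;
            long target = total * 8 * 1000 / maxRateBitsPerSec;   // ms the bytes so far should take
            if (target > elapsed) {
                try {
                    sleeper.sleep(target - elapsed);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();           // keep the interrupt visible
                    throw new InterruptedIOException("upload interrupted");
                }
            }
        }
        return total;
    }
}
```

In real use you would pass `System::currentTimeMillis` and `Thread::sleep`. With a fake clock, copying 20,000 bytes at 80,000 bit/s sleeps 819, 819 and 362 ms, 2,000 ms in total. This only paces what is handed to the socket; the operating system's send buffer can still let the first chunks leave as a short burst.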