001package com.box.sdk;
002
import com.github.luben.zstd.ZstdInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import okhttp3.Interceptor;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.BufferedSource;
import okio.Okio;
import okio.Source;
import org.jetbrains.annotations.NotNull;
015
016/**
017 * Interceptor that adds zstd compression support to API requests.
018 * This interceptor adds zstd to the Accept-Encoding header and handles decompression of zstd responses.
019 */
020public class ZstdInterceptor implements Interceptor {
021    @NotNull
022    @Override
023    public Response intercept(Chain chain) throws IOException {
024        Request request = chain.request();
025
026        // Add zstd to the Accept-Encoding header
027        String acceptEncoding;
028        String acceptEncodingHeader = request.header("Accept-Encoding");
029        if (acceptEncodingHeader == null || acceptEncodingHeader.isEmpty()) {
030            acceptEncoding = "zstd";
031        } else {
032            acceptEncoding = acceptEncodingHeader + ", zstd";
033        }
034
035        Request compressedRequest = request.newBuilder()
036            .removeHeader("Accept-Encoding")
037            .addHeader("Accept-Encoding", acceptEncoding)
038            .build();
039
040        Response response = chain.proceed(compressedRequest);
041        String contentEncoding = response.header("Content-Encoding");
042
043        // Only handle zstd encoded responses, let OkHttp handle gzip and others
044        if (contentEncoding == null || !contentEncoding.equalsIgnoreCase("zstd")) {
045            return response;
046        }
047
048        ResponseBody originalBody = response.body();
049        if (originalBody == null) {
050            return response;
051        }
052
053        // Create a streaming response body
054        ResponseBody decompressedBody = createStreamingResponseBody(originalBody);
055
056        return response.newBuilder()
057                .body(decompressedBody)
058                .addHeader("X-Content-Encoding", "zstd")
059                .removeHeader("Content-Encoding")
060                .removeHeader("Content-Length")
061                .build();
062    }
063
064    /**
065     * Wraps the original response body in a streaming Zstd decompressor.
066     */
067    private ResponseBody createStreamingResponseBody(ResponseBody originalBody) {
068        return new ResponseBody() {
069            @Override
070            public MediaType contentType() {
071                return originalBody.contentType();
072            }
073
074            @Override
075            public long contentLength() {
076                return -1;
077            }
078
079            @Override
080            public BufferedSource source() {
081                InputStream decompressedStream;
082                try {
083                    decompressedStream = new ZstdInputStream(originalBody.byteStream());
084                } catch (IOException e) {
085                    throw new RuntimeException("Failed to create ZstdInputStream", e);
086                }
087
088                Source source = Okio.source(decompressedStream);
089                return Okio.buffer(source);
090            }
091        };
092    }
093}