基于文件排序的简单实现

当需要从数据库查询大量数据时,可以分页查询,将每页数据排序后写入磁盘文件,以防止内存溢出;然后每次从各个文件中取出部分数据,归并排序后写入合并文件。

package com.ls.sort;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.TreeSet;

/**
 * File-backed (external) merge sort.
 *
 * <p>Records are collected with {@link #add(Object)}. Every {@code writeToFileNum}
 * records are sorted in memory and spilled to a run file named
 * {@code temp_<n>.data}. {@link #sort()} then merges all run files and writes
 * the globally ordered records to {@code merge.data}. Both file names are
 * relative to the current working directory.
 *
 * <p>Not thread-safe.
 *
 * @param <T> record type; all records are expected to share one concrete class,
 *            because the class of the first record of each batch selects the
 *            fields that get serialized
 * @date : 2023/3/17 16:18
 **/
public class FileSort<T> {

    /** One entry per sorted run file, in creation order. */
    private final List<FileData<T>> fileDatas = new ArrayList<>();

    /** Records accepted by {@link #add(Object)} that have not been spilled yet. */
    private final List<T> list = new ArrayList<>();

    private final Comparator<T> comparator;

    /** Number of records per run file, and per batch written to the merge file. */
    private final int writeToFileNum;

    /** Created lazily on the first write to {@code merge.data}. */
    private ObjectWriter<T> mergeObjectWriter;

    /**
     * @param comparator     ordering used for the in-memory sort and for the merge
     * @param writeToFileNum number of records buffered before a run file is written
     * @throws NullPointerException     if {@code comparator} is {@code null}
     * @throws IllegalArgumentException if {@code writeToFileNum} is less than 1
     */
    public FileSort(Comparator<T> comparator, int writeToFileNum) {
        if (comparator == null) {
            throw new NullPointerException("comparator");
        }
        if (writeToFileNum < 1) {
            throw new IllegalArgumentException("writeToFileNum must be >= 1, got " + writeToFileNum);
        }
        this.comparator = comparator;
        this.writeToFileNum = writeToFileNum;
    }

    /**
     * Accepts one record; once {@code writeToFileNum} records are pending they
     * are sorted and spilled to a new run file.
     *
     * @param obj record to add
     */
    public void add(T obj) {
        list.add(obj);
        if (list.size() >= writeToFileNum) {
            spillPending();
        }
    }

    /** Sorts the pending records, writes them as a new run file and clears the pending list. */
    private void spillPending() {
        list.sort(comparator);
        writeToFile(list);
        list.clear();
    }

    /** Writes an already sorted, non-empty batch to the next {@code temp_<n>.data} run file. */
    @SuppressWarnings("unchecked")
    private void writeToFile(List<T> list) {
        File file = new File(String.format("temp_%d.data", fileDatas.size()));

        // getClass() on a T yields Class<? extends Object>; the cast is safe because
        // the instance is a T.
        Class<T> clazz = (Class<T>) list.get(0).getClass();
        FileData<T> fileData = new FileData<>(file, list.size(), clazz, new HeapByteBuf(1024));
        fileData.writeToFile(list);

        fileDatas.add(fileData);
    }

    /**
     * Appends a batch of merged records to {@code merge.data}. An empty batch is a no-op.
     *
     * @throws UncheckedIOException if the underlying writer fails
     */
    @SuppressWarnings("unchecked")
    private void writeToMergeFile(List<T> list) {
        if (list.isEmpty()) {
            return;
        }
        if (mergeObjectWriter == null) {
            mergeObjectWriter = new ObjectWriter<>(
                    (Class<T>) list.get(0).getClass(), new HeapByteBuf(1024), new File("merge.data"));
        }
        try {
            for (T t : list) {
                mergeObjectWriter.writeObject(t);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write merge.data", e);
        }
    }

    /**
     * Merges every run file into {@code merge.data}.
     *
     * <p>Records that were added but not yet spilled are written as a final run
     * first, so nothing passed to {@link #add(Object)} is left out. When no
     * record was ever added the method returns without touching the disk.
     *
     * @throws UncheckedIOException if writing the merge file fails
     */
    public void sort() {
        if (!list.isEmpty()) {
            spillPending();
        }
        if (fileDatas.isEmpty()) {
            return;
        }

        // Records each run loads per refill; never below 1 so a run can always make progress.
        int fetchSize = Math.max(1, writeToFileNum / fileDatas.size());
        for (FileData<T> fileData : fileDatas) {
            fileData.setFetchSize(fetchSize);
        }

        // Runs ordered by their current head record. A priority queue keeps runs whose
        // heads compare equal; a sorted set would treat them as duplicates and drop them.
        Comparator<FileData<T>> byHead = (o1, o2) -> comparator.compare(o1.getNext(), o2.getNext());
        PriorityQueue<FileData<T>> heads = new PriorityQueue<>(fileDatas.size(), byHead);
        for (FileData<T> fileData : fileDatas) {
            if (fileData.hasNext()) {
                heads.add(fileData);
            }
        }

        List<T> merged = new ArrayList<>(writeToFileNum);
        while (!heads.isEmpty()) {
            // Take the run with the smallest head and consume that record.
            FileData<T> fileData = heads.poll();
            T next = fileData.next();
            merged.add(next);
            System.out.println("排序后:" + next);

            if (merged.size() >= writeToFileNum) {
                writeToMergeFile(merged);
                merged.clear();
            }

            // Re-queue the run under its new head while it still has records.
            if (fileData.hasNext()) {
                heads.add(fileData);
            }
        }

        // Write the last partial batch and push buffered bytes out to the file.
        writeToMergeFile(merged);
        merged.clear();
        if (mergeObjectWriter != null) {
            try {
                mergeObjectWriter.flush();
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to flush merge.data", e);
            }
        }
    }

}

/**
 * One sorted run stored in a file: written once with {@link #writeToFile(List)},
 * then read back in order, a chunk of records at a time.
 *
 * <p>The reader and the writer are handed the same {@link ByteBuf}, so all
 * writing must be finished (and flushed) before the first read.
 *
 * @param <T> record type
 */
class FileData<T> {

    private final ObjectWriter<T> writer;

    private final ObjectReader<T> reader;

    /** Records of the chunk currently held in memory. */
    private final List<T> partList = new ArrayList<>();

    /** Index, within the run, of the first record of the current chunk. */
    private int fetchStart = 0;

    /** Index, within the run, one past the last record of the current chunk. */
    private int fetchEnd = 0;

    /** Position of the next record to hand out inside {@link #partList}. */
    private int readIndex = 0;

    /** Total number of records in the run. */
    private final int writeIndex;

    /** Records loaded per refill; values that are unset or below 1 are treated as 1. */
    private Integer fetchSize;

    /**
     * @param file    run file
     * @param dataNum number of records the run holds
     * @param clazz   record class, used for reflective (de)serialization
     * @param buf     buffer shared by the reader and the writer
     */
    public FileData(File file, Integer dataNum, Class<T> clazz, ByteBuf buf) {
        this.writeIndex = dataNum;
        this.reader = new ObjectReader<>(clazz, buf, file);
        this.writer = new ObjectWriter<>(clazz, buf, file);
    }

    /** Sets how many records are loaded from the file per refill. */
    public void setFetchSize(Integer fetchSize) {
        this.fetchSize = fetchSize;
    }

    /** @return {@code true} while the run still has records that were not handed out */
    public boolean hasNext() {
        return fetchStart + readIndex < writeIndex;
    }

    /**
     * Returns the next record without consuming it, loading the next chunk
     * from the file first when the current one is used up.
     *
     * @throws IndexOutOfBoundsException if the run has no records left
     */
    public T getNext() {
        if (readIndex >= partList.size()) {
            readFromFile();
        }
        if (readIndex >= partList.size()) {
            throw new IndexOutOfBoundsException("No records left in this run (size " + writeIndex + ")");
        }
        return partList.get(readIndex);
    }

    /**
     * Returns the next record and advances past it.
     *
     * @throws IndexOutOfBoundsException if the run has no records left
     */
    public T next() {
        T head = getNext();
        readIndex++;
        return head;
    }

    /**
     * Replaces the in-memory chunk with the next up-to-{@code fetchSize} records.
     *
     * @throws UncheckedIOException  if reading the run file fails
     * @throws IllegalStateException if a record cannot be populated, or the file
     *                               ends before the expected number of records
     */
    private void readFromFile() {
        int step = (fetchSize == null || fetchSize < 1) ? 1 : fetchSize;

        fetchStart = fetchEnd;
        int remaining = writeIndex - fetchEnd;
        // Clamp without computing fetchEnd + step, which could overflow for a huge step.
        fetchEnd = (remaining <= step) ? writeIndex : fetchEnd + step;
        readIndex = 0;
        partList.clear();

        int expected = fetchEnd - fetchStart;
        try {
            for (int i = 0; i < expected; i++) {
                T object = reader.readObject();
                if (object == null) {
                    // Continuing here would silently lose records and desynchronise hasNext().
                    throw new IllegalStateException(
                            "Run file ended after " + (fetchStart + i) + " of " + writeIndex + " records");
                }
                partList.add(object);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read run file", e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot populate record fields", e);
        }
    }

    /**
     * Serializes {@code list} to the run file and flushes the writer.
     *
     * @param list records to write, already in the desired order
     * @throws UncheckedIOException if writing fails
     */
    public void writeToFile(List<T> list) {
        try {
            for (T t : list) {
                writer.writeObject(t);
            }
            writer.flush();
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write run file", e);
        }
    }
}

/**
 * Serializes objects of one class into a {@link ByteBuf} and writes the buffer
 * to a file.
 *
 * <p>Record layout, in order:
 * <ol>
 *   <li>4 bytes: total record length, including these 4 bytes;</li>
 *   <li>4 bytes per declared {@code String} field: byte lengths of the non-null
 *       strings, stored in field order starting at the first slot;</li>
 *   <li>null bitmap of {@code Bits.NullRowSizeFor(fieldCount)} bytes, one bit per
 *       declared field (bit set = no value stored);</li>
 *   <li>the values of the non-null fields in declaration order.</li>
 * </ol>
 *
 * <p>Only fields declared as {@code Integer}, {@code Long} or {@code String}
 * carry data. Fields of any other type are recorded as null.
 *
 * @param <T> type of the objects written
 */
class ObjectWriter<T> {
    private final File file;
    /** Opened lazily by {@link #flush()}. */
    private FileOutputStream fos;
    /** Whether the file was already opened once; later re-opens append instead of truncating. */
    private boolean created;
    private final ByteBuf buff;
    private int max = 0; // largest record length seen so far, in bytes

    private final Field[] fields;

    // number of declared String fields (variable-length fields)
    private int varLen = 0;

    /**
     * @param clazz class whose declared fields are serialized
     * @param buff  staging buffer; a record must fit into it entirely
     * @param file  destination file, created or truncated on the first flush
     */
    public ObjectWriter(Class<T> clazz, ByteBuf buff, File file) {
        this.buff = buff;
        this.file = file;
        this.fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            Class<?> type = field.getType();
            if (type == String.class) {
                varLen++;
            }
        }
    }

    /**
     * Appends one record to the buffer, flushing to the file when the buffer
     * cannot hold it or is getting close to full.
     *
     * @param object object to serialize
     * @throws IOException           if the record is larger than the whole buffer,
     *                               or writing to the file fails
     * @throws IllegalStateException if a field cannot be read reflectively
     */
    public void writeObject(T object) throws IOException {
        final int fieldCount = fields.length;
        final int nullRowSize = Bits.NullRowSizeFor(fieldCount);

        // Pass 1: collect the encodable values and the exact record size, so the
        // buffer can be flushed before the record is written rather than
        // overflowing half-way through it.
        Object[] payload = new Object[fieldCount]; // Integer, Long or byte[]; null = nothing stored
        int len = 4 + varLen * 4 + nullRowSize;
        for (int i = 0; i < fieldCount; i++) {
            Field field = fields[i];
            Object value;
            try {
                field.setAccessible(true);
                value = field.get(object);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot read field " + field.getName(), e);
            }
            if (value == null) {
                continue;
            }

            Class<?> type = field.getType();
            if (type == Integer.class) {
                payload[i] = value;
                len += 4;
            } else if (type == Long.class) {
                payload[i] = value;
                len += 8;
            } else if (type == String.class) {
                byte[] bytes = ((String) value).getBytes(StandardCharsets.UTF_8);
                payload[i] = bytes;
                len += bytes.length;
            }
            // Any other type stays null in payload and is flagged as null below.
        }

        if (len > buff.capacity()) {
            throw new IOException(
                    "Record of " + len + " bytes does not fit into a buffer of " + buff.capacity() + " bytes");
        }
        if (len > buff.capacity() - buff.writeIndex()) {
            flush();
        }

        // Pass 2: reserve the header regions, then write the values.
        int cursor = buff.writeIndex();
        ByteBuf totalLen = buff.slice(cursor, 4);
        cursor += 4;
        ByteBuf varByteBuf = buff.slice(cursor, varLen * 4);
        cursor += varLen * 4;
        ByteBuf nullBits = buff.slice(cursor, nullRowSize);
        cursor += nullRowSize;
        buff.writerIndex(cursor);

        // The buffer is reused after every flush, so clear leftover bits first.
        for (int k = 0; k < nullRowSize; k++) {
            nullBits.setByte(k, (byte) 0);
        }
        RowByte rowByte = new RowByte(nullBits);

        for (int i = 0; i < fieldCount; i++) {
            Object item = payload[i];
            if (item == null) {
                rowByte.setNull(i);
            } else if (item instanceof Integer) {
                buff.writeInt((Integer) item);
            } else if (item instanceof Long) {
                buff.writeLong((Long) item);
            } else {
                byte[] bytes = (byte[]) item;
                varByteBuf.writeInt(bytes.length); // length goes into the header slot
                buff.writeBytes(bytes);            // content goes into the data area
            }
        }

        totalLen.writeInt(len);
        max = Math.max(max, len);

        // Flush early once fewer than two largest-record sizes of space remain.
        if (max * 2 > buff.capacity() - buff.writeIndex()) {
            flush();
        }
    }

    /**
     * Writes the buffered bytes to the file and resets the buffer. The file is
     * opened on first use; failing to open it is reported to the caller.
     *
     * @throws IOException if the file cannot be opened or written
     */
    public void flush() throws IOException {
        if (fos == null) {
            fos = new FileOutputStream(file, created);
            created = true;
        }
        fos.write(buff.array(), buff.readIndex(), buff.writeIndex() - buff.readIndex());
        buff.reset();
    }

    /**
     * Releases the underlying file stream. Buffered bytes are NOT flushed
     * implicitly; call {@link #flush()} first. A later flush re-opens the file
     * in append mode.
     *
     * @throws IOException if closing the stream fails
     */
    public void close() throws IOException {
        if (fos != null) {
            FileOutputStream stream = fos;
            fos = null;
            stream.close();
        }
    }

}


/**
 * Reads back records of one class from a file, using the layout: 4-byte total
 * record length, one 4-byte length slot per declared {@code String} field, a
 * null bitmap of {@code Bits.NullRowSizeFor(fieldCount)} bytes, then the
 * values of the non-null fields in declaration order.
 *
 * <p>Only fields declared as {@code Integer}, {@code Long} or {@code String}
 * are populated; fields of any other type are left as the no-arg constructor
 * initialised them.
 *
 * @param <T> type of the objects read; must have a no-arg constructor
 */
class ObjectReader<T> {
    private final File file;
    /** Opened lazily on the first refill, because the file may not exist at construction time. */
    private FileInputStream fis;
    private final ByteBuf buff;
    private final Class<T> clazz;
    private final Field[] fields;
    // number of declared String fields (variable-length fields)
    private int varLen = 0;

    /**
     * @param clazz class to instantiate and populate
     * @param buff  staging buffer; must be large enough for a whole record
     * @param file  file to read from
     */
    public ObjectReader(Class<T> clazz, ByteBuf buff, File file) {
        this.clazz = clazz;
        this.buff = buff;
        this.file = file;

        this.fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            Class<?> type = field.getType();
            if (type == String.class) {
                varLen++;
            }
        }
    }

    /**
     * Reads the next record.
     *
     * @return the populated object, or {@code null} when the file is exhausted
     * @throws IOException            if reading fails, or the data is truncated or
     *                                has an impossible record length
     * @throws IllegalAccessException if the constructor or a field is not accessible
     * @throws IllegalStateException  if the class cannot be instantiated
     */
    public T readObject() throws IOException, IllegalAccessException {
        if (!ensureAvailable(4)) {
            if (available() == 0) {
                return null; // clean end of file
            }
            throw new IOException("Truncated record header in " + file);
        }

        // Total record length; the value includes the 4 bytes just consumed.
        int len = buff.slice(buff.readIndex(), 4).getInt();
        buff.readerIndex(buff.readIndex() + 4);

        int nullRowSize = Bits.NullRowSizeFor(fields.length);
        int bodyLen = len - 4;
        if (bodyLen < varLen * 4 + nullRowSize || bodyLen > buff.capacity()) {
            throw new IOException("Invalid record length " + len + " in " + file);
        }
        if (!ensureAvailable(bodyLen)) {
            throw new IOException("Truncated record in " + file);
        }
        // No refill happens past this point, so positions in the buffer stay valid.
        int recordEnd = buff.readIndex() + bodyLen;

        T object = newInstance();

        // Length slots of the variable-length fields.
        ByteBuf varByteBuf = buff.slice(buff.readIndex(), varLen * 4);
        buff.readerIndex(buff.readIndex() + varLen * 4);

        // Null bitmap.
        RowByte rowByte = new RowByte(buff.slice(buff.readIndex(), nullRowSize));
        buff.readerIndex(buff.readIndex() + nullRowSize);

        for (int i = 0; i < fields.length; i++) {
            Field field = fields[i];
            if (rowByte.isNull(i)) {
                continue;
            }

            Class<?> type = field.getType();
            Object value;
            if (type == Integer.class) {
                value = buff.readInt();
            } else if (type == Long.class) {
                value = buff.readLong();
            } else if (type == String.class) {
                int strLen = varByteBuf.readInt();
                value = buff.readString(strLen);
            } else {
                // Unsupported type: nothing was stored for it, keep the field's default.
                continue;
            }

            field.setAccessible(true);
            field.set(object, value);
        }

        // Resynchronise on the record boundary given by the length prefix.
        buff.readerIndex(recordEnd);
        return object;
    }

    /** Creates an empty instance through the no-arg constructor. */
    private T newInstance() throws IllegalAccessException {
        try {
            return clazz.getDeclaredConstructor().newInstance();
        } catch (IllegalAccessException e) {
            throw e;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "Cannot instantiate " + clazz.getName() + " through a no-arg constructor", e);
        }
    }

    /** Number of unread bytes currently in the buffer. */
    private int available() {
        return buff.writeIndex() - buff.readIndex();
    }

    /**
     * Refills until at least {@code needed} unread bytes are buffered.
     *
     * @return {@code false} if a refill made no progress (end of file or buffer
     *         full) before enough bytes were available
     */
    private boolean ensureAvailable(int needed) throws IOException {
        int have = available();
        while (have < needed) {
            fillBuffer();
            int now = available();
            if (now == have) {
                return false;
            }
            have = now;
        }
        return true;
    }

    /**
     * Compacts the buffer and appends as many bytes from the file as fit.
     * Reaching end of file leaves the buffer unchanged.
     *
     * @throws IOException if the file cannot be opened or read
     */
    public void fillBuffer() throws IOException {
        if (fis == null) {
            fis = new FileInputStream(file);
        }
        buff.discard();

        byte[] array = buff.array();

        int len = fis.read(array, buff.writeIndex(), buff.capacity() - buff.writeIndex());
        if (len > 0) { // read() returns -1 at end of file; that must not move the index
            buff.writerIndex(buff.writeIndex() + len);
        }
    }

    /**
     * Releases the underlying file stream, if it was opened.
     *
     * @throws IOException if closing the stream fails
     */
    public void close() throws IOException {
        if (fis != null) {
            FileInputStream stream = fis;
            fis = null;
            stream.close();
        }
    }

}

/**
 * Minimal byte buffer abstraction with separate read and write positions.
 *
 * <p>Note the naming: {@link #readIndex()} / {@link #writeIndex()} are the
 * getters, {@link #readerIndex(int)} / {@link #writerIndex(int)} the setters.
 */
interface ByteBuf {

    /** @return the number of bytes this buffer spans */
    int capacity();

    /** @return the current read position */
    int readIndex();

    /** @return the current write position */
    int writeIndex();

    /** Sets the write position. The implementations in this file do not validate the value. */
    void writerIndex(int writerIndex);

    /** Sets the read position. The implementations in this file do not validate the value. */
    void readerIndex(int readerIndex);

    /**
     * Creates a view of {@code length} bytes starting at {@code index}. The
     * implementations in this file return a view that shares the backing array
     * and has its own positions, both starting at 0.
     */
    ByteBuf slice(int index, int length);

    // Write operations
    /** Stores {@code value} as 4 bytes, most significant byte first, and advances the buffer's position by 4. */
    void writeInt(Integer value);

    /** Stores {@code value} as 8 bytes, most significant byte first, and advances the buffer's position. */
    void writeLong(Long value);

    /** Copies all of {@code value} into the buffer and advances the buffer's position by its length. */
    void writeBytes(byte[] value);

    /** Stores one byte at {@code index} without moving any position. */
    void setByte(int index, byte value);

    // Read operations
    /** Reads 4 bytes at the read position as an int and advances the read position by 4. */
    int readInt();

    /** Reads 8 bytes at the read position as a long and advances the read position by 8. */
    long readLong();

    /** Decodes {@code len} bytes at the read position into a string and advances the read position by {@code len}. */
    String readString(int len);

    /** Reads 4 bytes at the read position as an int without advancing. */
    int getInt();

    /** Returns the byte at {@code index} without moving any position. */
    byte getByte(int index);

    // Miscellaneous
    /** @return the backing array; for a slice this is the array of the buffer it was created from */
    byte[] array();

    /** Sets both the read and the write position back to 0. */
    void reset();

    /** Moves the unread bytes to the start of the buffer and rebases both positions accordingly. */
    void discard();
}

/**
 * {@link ByteBuf} backed by a fixed-size {@code byte[]}.
 *
 * <p>Bytes in {@code [readIndex, writeIndex)} are readable; writes go to
 * {@code writeIndex}. The buffer never grows.
 */
class HeapByteBuf implements ByteBuf {
    private final int capacity;
    private int readIndex;
    private int writeIndex;
    private final byte[] memory;

    /**
     * @param capacity size of the backing array in bytes
     */
    public HeapByteBuf(int capacity) {
        this.memory = new byte[capacity];
        this.capacity = capacity;
    }

    @Override
    public int capacity() {
        return capacity;
    }

    @Override
    public int readIndex() {
        return readIndex;
    }

    @Override
    public int writeIndex() {
        return writeIndex;
    }

    @Override
    public void writerIndex(int writerIndex) {
        this.writeIndex = writerIndex;
    }

    @Override
    public void readerIndex(int readerIndex) {
        this.readIndex = readerIndex;
    }

    /** Returns a view over the backing array; nothing is copied. */
    @Override
    public ByteBuf slice(int index, int length) {
        return new SliceByteBuf(this, index, length);
    }

    @Override
    public void writeInt(Integer value) {
        checkWriteIndex(4);
        Bits.setInt(memory, writeIndex, value);
        writeIndex += 4;
    }

    @Override
    public void writeLong(Long value) {
        checkWriteIndex(8);
        Bits.setLong(memory, writeIndex, value);
        writeIndex += 8;
    }

    @Override
    public void writeBytes(byte[] value) {
        checkWriteIndex(value.length);
        System.arraycopy(value, 0, memory, writeIndex, value.length);
        writeIndex += value.length;
    }

    @Override
    public void setByte(int index, byte value) {
        memory[index] = value;
    }

    @Override
    public int readInt() {
        checkReadIndex(4);
        int value = Bits.getInt(memory, readIndex);
        readIndex += 4;
        return value;
    }

    /** Rejects a read of {@code len} bytes that would run past the written data. */
    private void checkReadIndex(int len) {
        if (readIndex + len > writeIndex) {
            throw new ArrayIndexOutOfBoundsException(
                    String.format("访问越界, readIndex=%d, writerIndex=%d, len=%d", readIndex, writeIndex, len)
            );
        }
    }

    /**
     * Rejects a write of {@code len} bytes that would run past the backing array,
     * before any byte is stored, so a failed write is never partial.
     */
    private void checkWriteIndex(int len) {
        if (writeIndex < 0 || len > capacity - writeIndex) {
            throw new ArrayIndexOutOfBoundsException(
                    String.format("write out of bounds, writeIndex=%d, capacity=%d, len=%d",
                            writeIndex, capacity, len)
            );
        }
    }

    @Override
    public long readLong() {
        checkReadIndex(8);
        long value = Bits.getLong(memory, readIndex);
        readIndex += 8;
        return value;
    }

    /** Decodes {@code len} bytes as UTF-8, the encoding the record writer uses for strings. */
    @Override
    public String readString(int len) {
        checkReadIndex(len);
        String value = new String(memory, readIndex, len, StandardCharsets.UTF_8);
        readIndex += len;
        return value;
    }

    @Override
    public int getInt() {
        checkReadIndex(4);
        return Bits.getInt(memory, readIndex);
    }

    /**
     * Returns the byte at an absolute position. Only {@code index} is validated;
     * the read position plays no part in this access.
     */
    @Override
    public byte getByte(int index) {
        if (index < 0 || index >= capacity) {
            throw new ArrayIndexOutOfBoundsException(
                    String.format("index out of bounds, index=%d, capacity=%d", index, capacity)
            );
        }
        return memory[index];
    }

    @Override
    public byte[] array() {
        return memory;
    }

    @Override
    public void reset() {
        readIndex = 0;
        writeIndex = 0;
    }

    /** Moves the unread bytes to the start of the array so the freed space can be refilled. */
    @Override
    public void discard() {
        System.arraycopy(memory, readIndex, memory, 0, writeIndex - readIndex);
        writeIndex = writeIndex - readIndex;
        readIndex = 0;
    }

}

/**
 * A window onto another buffer's backing array, with its own read and write
 * positions, both relative to the start of the window.
 *
 * <p>No bounds are enforced against the window length: accesses are limited
 * only by the size of the backing array.
 */
class SliceByteBuf implements ByteBuf {
    private final int capacity;
    private int readIndex;
    private int writeIndex;

    /** Offset of this window inside {@code buf.array()}. */
    private final int index;
    private final ByteBuf buf;

    /**
     * @param buf    buffer whose backing array is shared
     * @param index  offset of the window inside {@code buf.array()}
     * @param length length of the window in bytes
     */
    public SliceByteBuf(ByteBuf buf, int index, int length) {
        this.buf = buf;
        this.index = index;
        this.capacity = length;
    }

    /** Translates a window-relative position into a position in the backing array. */
    private int idx(int idx) {
        return this.index + idx;
    }

    @Override
    public int capacity() {
        return this.capacity;
    }

    @Override
    public int readIndex() {
        return readIndex;
    }

    @Override
    public int writeIndex() {
        return writeIndex;
    }

    @Override
    public void writerIndex(int writerIndex) {
        this.writeIndex = writerIndex;
    }

    @Override
    public void readerIndex(int readerIndex) {
        this.readIndex = readerIndex;
    }

    /**
     * Creates a sub-window. The new slice is rooted at the same underlying
     * buffer with this window's offset added, so nested slices address the
     * correct bytes.
     */
    @Override
    public ByteBuf slice(int index, int length) {
        return new SliceByteBuf(buf, idx(index), length);
    }

    @Override
    public void writeInt(Integer value) {
        Bits.setInt(buf.array(), idx(writeIndex), value);
        writeIndex += 4;
    }

    @Override
    public void writeLong(Long value) {
        Bits.setLong(buf.array(), idx(writeIndex), value);
        writeIndex += 8;
    }

    @Override
    public void writeBytes(byte[] value) {
        System.arraycopy(value, 0, buf.array(), idx(writeIndex), value.length);
        writeIndex += value.length;
    }

    @Override
    public void setByte(int index, byte value) {
        buf.array()[idx(index)] = value;
    }

    @Override
    public int readInt() {
        int value = Bits.getInt(buf.array(), idx(readIndex));
        readIndex += 4;
        return value;
    }

    @Override
    public long readLong() {
        long value = Bits.getLong(buf.array(), idx(readIndex));
        readIndex += 8;
        return value;
    }

    /** Decodes {@code len} bytes as UTF-8, the encoding the record writer uses for strings. */
    @Override
    public String readString(int len) {
        String value = new String(buf.array(), idx(readIndex), len, StandardCharsets.UTF_8);
        readIndex += len;
        return value;
    }

    @Override
    public int getInt() {
        return Bits.getInt(buf.array(), idx(readIndex));
    }

    @Override
    public byte getByte(int index) {
        return buf.array()[idx(index)];
    }

    @Override
    public byte[] array() {
        return buf.array();
    }

    @Override
    public void reset() {
        readIndex = 0;
        writeIndex = 0;
    }

    /** Moves the unread bytes to the start of the window and rebases both positions. */
    @Override
    public void discard() {
        System.arraycopy(buf.array(), idx(readIndex), buf.array(), idx(0), writeIndex - readIndex);
        writeIndex = writeIndex - readIndex;
        readIndex = 0;
    }
}

/**
 * Null bitmap over a {@link ByteBuf}: one bit per field, where field
 * {@code index} lives in byte {@code index / 8} at bit {@code index % 8}.
 * A set bit means the field is null.
 */
class RowByte {
    private final ByteBuf buf;

    /**
     * @param buf buffer holding the bitmap bytes; byte 0 of the buffer holds fields 0-7
     */
    public RowByte(ByteBuf buf) {
        this.buf = buf;
    }

    /** Bit mask selecting {@code index}'s bit inside its byte. */
    private static int maskFor(int index) {
        return 1 << (index % 8);
    }

    /** Marks field {@code index} as null; the other bits are left untouched. */
    public void setNull(int index) {
        int i = index / 8;
        byte b = buf.getByte(i);
        buf.setByte(i, (byte) (b | maskFor(index)));
    }

    /** Marks field {@code index} as not null by clearing only its own bit. */
    public void setNotNull(int index) {
        int i = index / 8;
        byte b = buf.getByte(i);
        buf.setByte(i, (byte) (b & ~maskFor(index)));
    }

    /** @return {@code true} if the bit of field {@code index} is set */
    public boolean isNull(int index) {
        byte b = buf.getByte(index / 8);
        return (b & maskFor(index)) != 0;
    }

}

/**
 * Static helpers for big-endian integer encoding and for sizing the null bitmap.
 */
class Bits {
    /**
     * Size in bytes of the null bitmap for {@code cap} fields: {@code cap} is
     * rounded up to a power of two, divided by 8, and the result is at least 1.
     */
    static final int NullRowSizeFor(int cap) {
        int n = cap - 1;
        // Smear the highest set bit into every lower position.
        for (int shift = 1; shift <= 16; shift <<= 1) {
            n |= n >>> shift;
        }
        int rounded = (n < 0) ? 1 : n + 1;
        return Math.max(1, rounded / 8);
    }

    /** Stores {@code value} in 4 bytes starting at {@code index}, most significant byte first. */
    public static void setInt(byte[] memory, int index, int value) {
        int shift = 24;
        for (int offset = 0; offset < 4; offset++) {
            memory[index + offset] = (byte) (value >>> shift);
            shift -= 8;
        }
    }

    /** Stores {@code value} in 8 bytes starting at {@code index}, most significant byte first. */
    public static void setLong(byte[] memory, int index, long value) {
        int shift = 56;
        for (int offset = 0; offset < 8; offset++) {
            memory[index + offset] = (byte) (value >>> shift);
            shift -= 8;
        }
    }

    /** Assembles an int from the 4 bytes starting at {@code index}, most significant byte first. */
    static int getInt(byte[] memory, int index) {
        int result = 0;
        for (int offset = 0; offset < 4; offset++) {
            result = (result << 8) | (memory[index + offset] & 0xff);
        }
        return result;
    }

    /** Assembles a long from the 8 bytes starting at {@code index}, most significant byte first. */
    static long getLong(byte[] memory, int index) {
        long result = 0L;
        for (int offset = 0; offset < 8; offset++) {
            result = (result << 8) | (memory[index + offset] & 0xffL);
        }
        return result;
    }
}
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容