当我们需要从数据库查询大量数据时,可以分页查询,将每页数据排序后写入磁盘临时文件,以防止内存溢出;然后再从每个文件中分批取出部分数据,归并排序后写入合并文件。
package com.ls.sort;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.TreeSet;
/**
 * File sort: sorts an unbounded number of records with bounded memory.
 *
 * <p>{@link #add(Object)} buffers records; every {@code writeToFileNum} records the buffer is
 * sorted and spilled to a temp file ({@code temp_N.data}). {@link #sort()} spills any remaining
 * records, then k-way merges all temp files into {@code merge.data}, reading each file back in
 * small batches.
 *
 * @date : 2023/3/17 16:18
 **/
public class FileSort<T> {
    private static final String TEMP_FILE_PATTERN = "temp_%d.data";
    private static final String MERGE_FILE_NAME = "merge.data";
    private static final int BUFFER_CAPACITY = 1024;

    // Sorted runs already spilled to disk, in creation order.
    private final List<FileData<T>> fileDatas = new ArrayList<>();
    // Records added since the last spill.
    private final List<T> list = new ArrayList<>();
    private final Comparator<T> comparator;
    // How many records are buffered before a run is written (also the merge output batch size).
    private final int writeToFileNum;
    private ObjectWriter<T> mergeObjectWriter;

    /**
     * @param comparator     ordering of the records
     * @param writeToFileNum records per spilled run; must be positive
     * @throws IllegalArgumentException if {@code writeToFileNum <= 0}
     */
    public FileSort(Comparator<T> comparator, int writeToFileNum) {
        if (writeToFileNum <= 0) {
            throw new IllegalArgumentException("writeToFileNum must be positive: " + writeToFileNum);
        }
        this.comparator = comparator;
        this.writeToFileNum = writeToFileNum;
    }

    /** Buffers one record, spilling a sorted run to disk once the buffer is full. */
    public void add(T obj) {
        list.add(obj);
        if (list.size() >= writeToFileNum) {
            spill();
        }
    }

    /** Sorts the buffered records, writes them as a new run and empties the buffer. */
    private void spill() {
        list.sort(comparator);
        writeToFile(list);
        list.clear();
    }

    /** Writes one (already sorted, non-empty) batch to a fresh temp file and registers it. */
    private void writeToFile(List<T> batch) {
        File file = new File(String.format(TEMP_FILE_PATTERN, fileDatas.size()));
        // NOTE(review): the serializer layout is taken from the first element's runtime class;
        // mixing subclasses with different fields in one FileSort is not supported.
        FileData<T> fileData = new FileData<>(
                file, batch.size(), classOf(batch.get(0)), new HeapByteBuf(BUFFER_CAPACITY));
        fileData.writeToFile(batch);
        fileDatas.add(fileData);
    }

    /** Appends a batch of merged records to the merge file; empty batches are ignored. */
    private void writeToMergeFile(List<T> batch) {
        if (batch.isEmpty()) {
            return;
        }
        if (mergeObjectWriter == null) {
            mergeObjectWriter = new ObjectWriter<>(
                    classOf(batch.get(0)), new HeapByteBuf(BUFFER_CAPACITY), new File(MERGE_FILE_NAME));
        }
        try {
            for (T t : batch) {
                mergeObjectWriter.writeObject(t);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Pushes any bytes still buffered by the merge writer out to the merge file. */
    private void flushMergeFile() {
        if (mergeObjectWriter == null) {
            return;
        }
        try {
            mergeObjectWriter.flush();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // NOTE(review): the merge file's output stream is left open; consider closing it here.
    }

    /**
     * Merges all runs into the merge file in {@code comparator} order.
     *
     * @throws UncheckedIOException if reading a run or writing the merge file fails
     */
    public void sort() {
        // Records added after the last full batch have not reached disk yet.
        if (!list.isEmpty()) {
            spill();
        }
        if (fileDatas.isEmpty()) {
            return;
        }
        // Split the in-memory budget across runs, but always read at least one record at a time.
        int fetchSize = Math.max(1, writeToFileNum / fileDatas.size());
        for (FileData<T> fileData : fileDatas) {
            fileData.setFetchSize(fetchSize);
        }
        // A PriorityQueue (not a TreeSet) so runs whose current heads compare equal are all kept.
        Comparator<FileData<T>> byHead = (a, b) -> comparator.compare(a.getNext(), b.getNext());
        PriorityQueue<FileData<T>> heap = new PriorityQueue<>(fileDatas.size(), byHead);
        for (FileData<T> fileData : fileDatas) {
            if (fileData.hasNext()) {
                heap.add(fileData);
            }
        }
        List<T> batch = new ArrayList<>(writeToFileNum);
        while (!heap.isEmpty()) {
            // Take the run with the smallest head and consume that head.
            FileData<T> fileData = heap.poll();
            T next = fileData.next();
            batch.add(next);
            System.out.println("排序后:" + next);
            if (batch.size() == writeToFileNum) {
                writeToMergeFile(batch);
                batch.clear();
            }
            // Re-insert with its new head if the run still has records.
            if (fileData.hasNext()) {
                heap.add(fileData);
            }
        }
        // The final partial batch and the writer's buffer would otherwise never reach the file.
        writeToMergeFile(batch);
        flushMergeFile();
    }

    /** Runtime class of {@code sample}, typed as its static type for the serializer. */
    @SuppressWarnings("unchecked")
    private static <E> Class<E> classOf(E sample) {
        return (Class<E>) sample.getClass();
    }
}
/**
 * One sorted run on disk. {@link #writeToFile(List)} writes the run once; afterwards the merge
 * phase streams it back with {@link #getNext()} (peek) and {@link #next()} (consume), holding at
 * most {@code fetchSize} records in memory at a time.
 *
 * @param <T> record type
 */
class FileData<T> {
    private final File file;
    private final Class<T> clazz;
    // Shared by writer and reader; all writing (and the writer's final flush) happens before reading.
    private final ByteBuf buf;
    private final ObjectWriter<T> writer;
    // Created lazily on the first read: the file is only created when the writer first flushes.
    private ObjectReader<T> reader;
    // Records of the current batch and the position of the next one to hand out.
    private final List<T> partList = new ArrayList<>();
    private int readIndex = 0;
    // Record range [fetchStart, fetchEnd) of the file that partList covers.
    private int fetchStart = 0;
    private int fetchEnd = 0;
    // Total number of records in the file.
    private final int writeIndex;
    // Maximum number of records loaded per batch (always >= 1).
    private int fetchSize = 1;

    /**
     * @param file    backing file of this run
     * @param dataNum number of records that will be written to it
     * @param clazz   record class used for (de)serialization
     * @param buf     I/O buffer shared by the writer and the reader
     */
    public FileData(File file, Integer dataNum, Class<T> clazz, ByteBuf buf) {
        this.file = file;
        this.clazz = clazz;
        this.buf = buf;
        this.writeIndex = dataNum;
        this.writer = new ObjectWriter<>(clazz, buf, file);
    }

    /** Sets how many records each batch loads; values below 1 are raised to 1. */
    public void setFetchSize(Integer fetchSize) {
        this.fetchSize = Math.max(1, fetchSize);
    }

    /** Returns whether any record of this run has not been consumed by {@link #next()} yet. */
    public boolean hasNext() {
        return fetchStart + readIndex < writeIndex;
    }

    /**
     * Returns the next record without consuming it.
     *
     * @throws NoSuchElementException if the run is exhausted
     */
    public T getNext() {
        ensureBuffered();
        return partList.get(readIndex);
    }

    /**
     * Returns and consumes the next record.
     *
     * @throws NoSuchElementException if the run is exhausted
     */
    public T next() {
        ensureBuffered();
        return partList.get(readIndex++);
    }

    /** Loads the next batch when the current one is used up. */
    private void ensureBuffered() {
        if (readIndex < partList.size()) {
            return;
        }
        if (!hasNext()) {
            throw new NoSuchElementException("no more records in " + file);
        }
        readFromFile();
    }

    /** Replaces {@code partList} with the next up-to-{@code fetchSize} records of the file. */
    private void readFromFile() {
        if (reader == null) {
            reader = new ObjectReader<>(clazz, buf, file);
        }
        fetchStart = fetchEnd;
        fetchEnd = Math.min(fetchEnd + fetchSize, writeIndex);
        readIndex = 0;
        partList.clear();
        int expected = fetchEnd - fetchStart;
        try {
            for (int i = 0; i < expected; i++) {
                T object = reader.readObject();
                if (object == null) {
                    throw new IllegalStateException(String.format(
                            "%s ended after %d of %d records", file, fetchStart + i, writeIndex));
                }
                partList.add(object);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("cannot populate record of " + clazz.getName(), e);
        }
    }

    /**
     * Serializes all records of {@code list} to this run's file and flushes them.
     *
     * @throws UncheckedIOException if writing fails
     */
    public void writeToFile(List<T> list) {
        try {
            for (T t : list) {
                writer.writeObject(t);
            }
            writer.flush();
            // NOTE(review): the writer's file stream is left open after this flush; consider
            // closing it once writing is finished.
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
/**
 * Serializes objects of one class into length-prefixed binary records and appends them to a file,
 * staging bytes in a reusable {@link ByteBuf}.
 *
 * <p>Record layout:
 * <pre>
 *   total length    4 bytes                       (counts these 4 bytes too)
 *   string lengths  4 bytes per String field      (filled, in field order, for non-null strings)
 *   null bitmap     Bits.NullRowSizeFor(#fields)  (one bit per declared field)
 *   field data      Integer/int 4 bytes, Long/long 8 bytes, String UTF-8 bytes
 * </pre>
 * Fields of any other type get a null-bitmap bit but no data. The file is created (truncated) on
 * the first {@link #flush()}.
 *
 * <p>NOTE(review): field order comes from {@link Class#getDeclaredFields()}, which the JDK does
 * not guarantee; this relies on the reader seeing the same order within one JVM run.
 */
class ObjectWriter<T> {
    private final File file;
    private FileOutputStream fos;
    private final ByteBuf buff;
    private final Field[] fields;
    // Number of String (variable-length) fields; each gets a 4-byte length slot per record.
    private int varLen = 0;

    public ObjectWriter(Class<T> clazz, ByteBuf buff, File file) {
        this.buff = buff;
        this.file = file;
        this.fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            if (field.getType() == String.class) {
                varLen++;
            }
        }
    }

    /**
     * Appends one record for {@code object}, flushing the buffer first if the record does not fit.
     *
     * @throws IOException if the record is larger than the buffer or flushing fails
     * @throws IllegalStateException if a field cannot be read reflectively
     */
    public void writeObject(T object) throws IOException {
        int nullRowSize = Bits.NullRowSizeFor(fields.length);
        Object[] values = new Object[fields.length];
        byte[][] encoded = new byte[fields.length][];
        // Pass 1: read values and compute the exact record size, so room is made *before*
        // anything is written (checking only afterwards could overrun the buffer).
        int len = 4 + varLen * 4 + nullRowSize;
        for (int i = 0; i < fields.length; i++) {
            Object value = readField(fields[i], object);
            values[i] = value;
            if (value == null) {
                continue;
            }
            Class<?> type = fields[i].getType();
            if (isInt(type)) {
                len += 4;
            } else if (isLong(type)) {
                len += 8;
            } else if (type == String.class) {
                encoded[i] = value.toString().getBytes(StandardCharsets.UTF_8);
                len += encoded[i].length;
            }
        }
        ensureWritable(len);

        // Pass 2: reserve the header regions, then write the field data behind them.
        ByteBuf totalLen = buff.slice(buff.writeIndex(), 4);
        buff.writerIndex(buff.writeIndex() + 4);
        ByteBuf varByteBuf = buff.slice(buff.writeIndex(), varLen * 4);
        buff.writerIndex(buff.writeIndex() + varLen * 4);
        RowByte rowByte = new RowByte(buff.slice(buff.writeIndex(), nullRowSize));
        buff.writerIndex(buff.writeIndex() + nullRowSize);
        for (int i = 0; i < fields.length; i++) {
            Object value = values[i];
            if (value == null) {
                rowByte.setNull(i);
                continue;
            }
            rowByte.setNotNull(i);
            Class<?> type = fields[i].getType();
            if (isInt(type)) {
                buff.writeInt((Integer) value);
            } else if (isLong(type)) {
                buff.writeLong((Long) value);
            } else if (type == String.class) {
                varByteBuf.writeInt(encoded[i].length);
                buff.writeBytes(encoded[i]);
            }
        }
        totalLen.writeInt(len);
    }

    /**
     * Writes all buffered bytes to the file (creating it on first use) and empties the buffer.
     *
     * @throws IOException if the file cannot be opened or written
     */
    public void flush() throws IOException {
        if (fos == null) {
            fos = new FileOutputStream(file);
        }
        fos.write(buff.array(), buff.readIndex(), buff.writeIndex() - buff.readIndex());
        buff.reset();
    }

    /** Flushes any buffered records and closes the file. */
    public void close() throws IOException {
        try {
            flush();
        } finally {
            if (fos != null) {
                fos.close();
                fos = null;
            }
        }
    }

    /** Makes sure {@code len} more bytes fit into the buffer, flushing if necessary. */
    private void ensureWritable(int len) throws IOException {
        if (len > buff.capacity()) {
            throw new IOException(String.format(
                    "record of %d bytes exceeds buffer capacity %d", len, buff.capacity()));
        }
        if (buff.capacity() - buff.writeIndex() < len) {
            flush();
        }
    }

    /** Reads {@code field} of {@code target}, bypassing access checks. */
    private static Object readField(Field field, Object target) {
        try {
            field.setAccessible(true);
            return field.get(target);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("cannot read field " + field, e);
        }
    }

    private static boolean isInt(Class<?> type) {
        return type == Integer.class || type == Integer.TYPE;
    }

    private static boolean isLong(Class<?> type) {
        return type == Long.class || type == Long.TYPE;
    }
}
/**
 * Reads back the length-prefixed records produced by {@code ObjectWriter} for the same class:
 * a 4-byte total length, one 4-byte length slot per String field, a null bitmap of
 * {@code Bits.NullRowSizeFor(#fields)} bytes, then Integer/int (4 bytes), Long/long (8 bytes) and
 * String (UTF-8) data for every non-null field in declared-field order.
 *
 * <p>The file is opened lazily on the first read, so the reader may be created before the file
 * has been written.
 */
class ObjectReader<T> {
    private final File file;
    private FileInputStream fis;
    private final ByteBuf buff;
    private final Class<T> clazz;
    private final Field[] fields;
    // Number of String (variable-length) fields; each has a 4-byte length slot per record.
    private int varLen = 0;

    public ObjectReader(Class<T> clazz, ByteBuf buff, File file) {
        this.clazz = clazz;
        this.buff = buff;
        this.file = file;
        this.fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            if (field.getType() == String.class) {
                varLen++;
            }
        }
    }

    /**
     * Reads the next record.
     *
     * @return the populated object, or {@code null} at a clean end of file
     * @throws EOFException if the file ends in the middle of a record
     * @throws IOException on I/O failure, a corrupt length, or a record larger than the buffer
     * @throws IllegalAccessException if a field cannot be set
     * @throws IllegalStateException if {@code T} cannot be instantiated with a no-arg constructor
     */
    public T readObject() throws IOException, IllegalAccessException {
        if (!ensureReadable(4)) {
            if (readable() == 0) {
                return null;
            }
            throw new EOFException("truncated record header in " + file);
        }
        // Total length, including its own 4 bytes.
        int len = buff.readInt();
        int nullRowSize = Bits.NullRowSizeFor(fields.length);
        if (len < 4 + varLen * 4 + nullRowSize) {
            throw new IOException("corrupt record length " + len + " in " + file);
        }
        // Load the whole record so the slices below stay valid (no refill mid-record).
        if (!ensureReadable(len - 4)) {
            throw new EOFException("truncated record in " + file);
        }
        ByteBuf varByteBuf = buff.slice(buff.readIndex(), varLen * 4);
        buff.readerIndex(buff.readIndex() + varLen * 4);
        RowByte rowByte = new RowByte(buff.slice(buff.readIndex(), nullRowSize));
        buff.readerIndex(buff.readIndex() + nullRowSize);

        T object = newInstance();
        for (int i = 0; i < fields.length; i++) {
            if (rowByte.isNull(i)) {
                continue;
            }
            Field field = fields[i];
            Class<?> type = field.getType();
            Object value;
            if (type == Integer.class || type == Integer.TYPE) {
                value = buff.readInt();
            } else if (type == Long.class || type == Long.TYPE) {
                value = buff.readLong();
            } else if (type == String.class) {
                value = buff.readString(varByteBuf.readInt());
            } else {
                // The writer stores no data for other types; keep the constructor's value.
                continue;
            }
            field.setAccessible(true);
            field.set(object, value);
        }
        return object;
    }

    /** Compacts the buffer and reads more bytes from the file, if any are available. */
    public void fillBuffer() throws IOException {
        fill();
    }

    /** Closes the underlying file, if it was opened. */
    public void close() throws IOException {
        if (fis != null) {
            fis.close();
            fis = null;
        }
    }

    private int readable() {
        return buff.writeIndex() - buff.readIndex();
    }

    /** Refills until at least {@code n} bytes are buffered; returns false if the file ends first. */
    private boolean ensureReadable(int n) throws IOException {
        if (n > buff.capacity()) {
            throw new IOException(String.format(
                    "record of %d bytes exceeds buffer capacity %d in %s", n, buff.capacity(), file));
        }
        while (readable() < n) {
            // read() may return fewer bytes than requested, so keep going until EOF.
            if (fill() <= 0) {
                return false;
            }
        }
        return true;
    }

    /** One compact-and-read step; returns bytes read, 0 if the buffer is full, or -1 at EOF. */
    private int fill() throws IOException {
        if (fis == null) {
            fis = new FileInputStream(file);
        }
        buff.discard();
        int free = buff.capacity() - buff.writeIndex();
        if (free == 0) {
            return 0;
        }
        int n = fis.read(buff.array(), buff.writeIndex(), free);
        // -1 signals end of stream and must not be added to the write index.
        if (n > 0) {
            buff.writerIndex(buff.writeIndex() + n);
        }
        return n;
    }

    private T newInstance() {
        try {
            return clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "cannot instantiate " + clazz.getName() + " (a no-arg constructor is required)", e);
        }
    }
}
/**
 * Minimal byte buffer with independent read and write cursors, used by the record serializer.
 * Positions are byte offsets; {@link #slice(int, int)} returns a view over part of the storage.
 */
interface ByteBuf {
    // ---- geometry and cursors ----

    /** Number of bytes this buffer can hold. */
    int capacity();

    /** Current read cursor. */
    int readIndex();

    /** Current write cursor. */
    int writeIndex();

    /** Moves the read cursor. */
    void readerIndex(int readerIndex);

    /** Moves the write cursor. */
    void writerIndex(int writerIndex);

    /** Returns a view of {@code length} bytes starting at {@code index}. */
    ByteBuf slice(int index, int length);

    // ---- writing ----

    void writeInt(Integer value);

    void writeLong(Long value);

    void writeBytes(byte[] value);

    /** Stores {@code value} at absolute position {@code index} without moving any cursor. */
    void setByte(int index, byte value);

    // ---- reading ----

    int readInt();

    long readLong();

    /** Decodes the next {@code len} bytes as a string. */
    String readString(int len);

    /** Reads an int at the read cursor without advancing it. */
    int getInt();

    /** Loads the byte at absolute position {@code index} without moving any cursor. */
    byte getByte(int index);

    // ---- storage management ----

    /** Backing array shared with the implementation. */
    byte[] array();

    /** Sets both cursors back to zero. */
    void reset();

    /** Drops already-read bytes by moving the unread ones to the front. */
    void discard();
}
/**
 * {@link ByteBuf} backed by a fixed-size heap byte array with independent read and write cursors.
 * Ints and longs are encoded with {@code Bits}. Strings are decoded as UTF-8, the same charset the
 * writer uses to encode them, so results do not depend on the platform default charset.
 */
class HeapByteBuf implements ByteBuf {
    private final int capacity;
    private int readIndex;
    private int writeIndex;
    private final byte[] memory;

    public HeapByteBuf(int capacity) {
        this.memory = new byte[capacity];
        this.capacity = capacity;
    }

    @Override
    public int capacity() {
        return capacity;
    }

    @Override
    public int readIndex() {
        return readIndex;
    }

    @Override
    public int writeIndex() {
        return writeIndex;
    }

    @Override
    public void writerIndex(int writerIndex) {
        this.writeIndex = writerIndex;
    }

    @Override
    public void readerIndex(int readerIndex) {
        this.readIndex = readerIndex;
    }

    /** Returns a view over {@code memory[index, index + length)}. */
    @Override
    public ByteBuf slice(int index, int length) {
        return new SliceByteBuf(this, index, length);
    }

    @Override
    public void writeInt(Integer value) {
        checkWritable(4);
        Bits.setInt(memory, writeIndex, value);
        writeIndex += 4;
    }

    @Override
    public void writeLong(Long value) {
        checkWritable(8);
        Bits.setLong(memory, writeIndex, value);
        writeIndex += 8;
    }

    @Override
    public void writeBytes(byte[] value) {
        checkWritable(value.length);
        System.arraycopy(value, 0, memory, writeIndex, value.length);
        writeIndex += value.length;
    }

    @Override
    public void setByte(int index, byte value) {
        memory[index] = value;
    }

    @Override
    public int readInt() {
        checkReadIndex(4);
        int value = Bits.getInt(memory, readIndex);
        readIndex += 4;
        return value;
    }

    /** Fails if fewer than {@code len} unread bytes are available. */
    private void checkReadIndex(int len) {
        if (readIndex + len > writeIndex) {
            throw new ArrayIndexOutOfBoundsException(
                    String.format("访问越界, readIndex=%d, writerIndex=%d, len=%d", readIndex, writeIndex, len)
            );
        }
    }

    /** Fails before writing if {@code len} more bytes do not fit, leaving the buffer untouched. */
    private void checkWritable(int len) {
        if (writeIndex + len > capacity) {
            throw new ArrayIndexOutOfBoundsException(String.format(
                    "write out of bounds, writeIndex=%d, capacity=%d, len=%d", writeIndex, capacity, len));
        }
    }

    @Override
    public long readLong() {
        checkReadIndex(8);
        long value = Bits.getLong(memory, readIndex);
        readIndex += 8;
        return value;
    }

    @Override
    public String readString(int len) {
        checkReadIndex(len);
        String value = new String(memory, readIndex, len, StandardCharsets.UTF_8);
        readIndex += len;
        return value;
    }

    /** Reads an int at the read cursor without advancing it. */
    @Override
    public int getInt() {
        checkReadIndex(4);
        return Bits.getInt(memory, readIndex);
    }

    /** Absolute read; the old read-cursor check was unrelated to {@code index} and is dropped. */
    @Override
    public byte getByte(int index) {
        return memory[index];
    }

    @Override
    public byte[] array() {
        return memory;
    }

    @Override
    public void reset() {
        readIndex = 0;
        writeIndex = 0;
    }

    /** Moves the unread bytes to the front of the array so more can be written after them. */
    @Override
    public void discard() {
        System.arraycopy(memory, readIndex, memory, 0, writeIndex - readIndex);
        writeIndex = writeIndex - readIndex;
        readIndex = 0;
    }
}
/**
 * Fixed-length view of {@code length} bytes of a parent buffer's backing array, starting at offset
 * {@code index} of {@code buf.array()}. It has its own read and write cursors, both relative to
 * the start of the slice. A slice usually covers bytes that already exist in the parent (e.g. a
 * record header), so reads are bounded by the slice's capacity rather than by its write cursor.
 */
class SliceByteBuf implements ByteBuf {
    private final int capacity;
    private int readIndex;
    private int writeIndex;
    // Offset of the slice's first byte inside buf.array().
    private final int index;
    private final ByteBuf buf;

    public SliceByteBuf(ByteBuf buf, int index, int length) {
        this.buf = buf;
        this.index = index;
        this.capacity = length;
    }

    /** Converts a slice-relative offset into an offset of the backing array. */
    private int idx(int offset) {
        return this.index + offset;
    }

    /** Rejects accesses that would leave the slice and silently touch neighbouring bytes. */
    private void checkBounds(int offset, int len) {
        if (offset < 0 || len < 0 || offset + len > capacity) {
            throw new ArrayIndexOutOfBoundsException(String.format(
                    "slice access out of bounds, offset=%d, len=%d, capacity=%d", offset, len, capacity));
        }
    }

    @Override
    public int capacity() {
        return this.capacity;
    }

    @Override
    public int readIndex() {
        return readIndex;
    }

    @Override
    public int writeIndex() {
        return writeIndex;
    }

    @Override
    public void writerIndex(int writerIndex) {
        this.writeIndex = writerIndex;
    }

    @Override
    public void readerIndex(int readerIndex) {
        this.readIndex = readerIndex;
    }

    @Override
    public ByteBuf slice(int index, int length) {
        checkBounds(index, length);
        // Re-base on the parent so the new slice's offset includes this slice's own offset.
        return new SliceByteBuf(buf, this.index + index, length);
    }

    @Override
    public void writeInt(Integer value) {
        checkBounds(writeIndex, 4);
        Bits.setInt(buf.array(), idx(writeIndex), value);
        writeIndex += 4;
    }

    @Override
    public void writeLong(Long value) {
        checkBounds(writeIndex, 8);
        Bits.setLong(buf.array(), idx(writeIndex), value);
        writeIndex += 8;
    }

    @Override
    public void writeBytes(byte[] value) {
        checkBounds(writeIndex, value.length);
        System.arraycopy(value, 0, buf.array(), idx(writeIndex), value.length);
        writeIndex += value.length;
    }

    @Override
    public void setByte(int index, byte value) {
        checkBounds(index, 1);
        buf.array()[idx(index)] = value;
    }

    @Override
    public int readInt() {
        checkBounds(readIndex, 4);
        int value = Bits.getInt(buf.array(), idx(readIndex));
        readIndex += 4;
        return value;
    }

    @Override
    public long readLong() {
        checkBounds(readIndex, 8);
        long value = Bits.getLong(buf.array(), idx(readIndex));
        readIndex += 8;
        return value;
    }

    @Override
    public String readString(int len) {
        checkBounds(readIndex, len);
        String value = new String(buf.array(), idx(readIndex), len, StandardCharsets.UTF_8);
        readIndex += len;
        return value;
    }

    /** Reads an int at the read cursor without advancing it. */
    @Override
    public int getInt() {
        checkBounds(readIndex, 4);
        return Bits.getInt(buf.array(), idx(readIndex));
    }

    @Override
    public byte getByte(int index) {
        checkBounds(index, 1);
        return buf.array()[idx(index)];
    }

    @Override
    public byte[] array() {
        return buf.array();
    }

    @Override
    public void reset() {
        readIndex = 0;
        writeIndex = 0;
    }

    /** Moves the slice's unread bytes to its start. */
    @Override
    public void discard() {
        System.arraycopy(buf.array(), idx(readIndex), buf.array(), idx(0), writeIndex - readIndex);
        writeIndex = writeIndex - readIndex;
        readIndex = 0;
    }
}
/**
 * Null bitmap over a byte region: bit {@code index % 8} of byte {@code index / 8} is 1 when field
 * {@code index} is null and 0 when it is not.
 */
class RowByte {
    private final ByteBuf buf;

    public RowByte(ByteBuf buf) {
        this.buf = buf;
    }

    /** Marks field {@code index} as null by setting its bit. */
    public void setNull(int index) {
        int i = index / 8;
        buf.setByte(i, (byte) (buf.getByte(i) | mask(index)));
    }

    /** Marks field {@code index} as non-null by clearing its bit and leaving the other bits intact. */
    public void setNotNull(int index) {
        int i = index / 8;
        // Must be "& ~mask": "& mask" would wipe the bits of every other field in this byte.
        buf.setByte(i, (byte) (buf.getByte(i) & ~mask(index)));
    }

    /** Returns whether field {@code index} is marked null, i.e. whether its own bit is set. */
    public boolean isNull(int index) {
        return (buf.getByte(index / 8) & mask(index)) != 0;
    }

    /** Single-bit mask for the field's position inside its byte. */
    private static int mask(int index) {
        return 1 << (index % 8);
    }
}
/**
 * Static helpers that encode ints and longs big-endian into byte arrays and size the per-record
 * null bitmap.
 */
class Bits {
    /**
     * Returns how many bytes to reserve for a null bitmap covering {@code cap} fields. The result
     * is the smallest power of two that is at least {@code cap}, divided by 8, and never less
     * than 1.
     *
     * <p>This can over-allocate; for example, 17 fields get 4 bytes. It must not change, because
     * both the writer and the reader rely on it. For {@code cap > 2^30} the power of two
     * overflows and the result degenerates to 1.
     */
    static final int NullRowSizeFor(int cap) {
        if (cap <= 1) {
            return 1;
        }
        int roundedUp = Integer.highestOneBit(cap - 1) << 1;
        int bytes = roundedUp / 8;
        return bytes > 0 ? bytes : 1;
    }

    /** Stores {@code value} big-endian at {@code memory[index .. index + 3]}. */
    public static void setInt(byte[] memory, int index, int value) {
        for (int k = 0, shift = 24; k < 4; k++, shift -= 8) {
            memory[index + k] = (byte) (value >>> shift);
        }
    }

    /** Stores {@code value} big-endian at {@code memory[index .. index + 7]}. */
    public static void setLong(byte[] memory, int index, long value) {
        for (int k = 0, shift = 56; k < 8; k++, shift -= 8) {
            memory[index + k] = (byte) (value >>> shift);
        }
    }

    /** Reads a big-endian int from {@code memory[index .. index + 3]}. */
    static int getInt(byte[] memory, int index) {
        int result = 0;
        for (int k = 0; k < 4; k++) {
            result = (result << 8) | (memory[index + k] & 0xff);
        }
        return result;
    }

    /** Reads a big-endian long from {@code memory[index .. index + 7]}. */
    static long getLong(byte[] memory, int index) {
        long result = 0L;
        for (int k = 0; k < 8; k++) {
            result = (result << 8) | (memory[index + k] & 0xffL);
        }
        return result;
    }
}