Optional Challenge
The user-level thread package interacts badly with the operating system in several ways. For example, if one user-level thread blocks in a system call, another user-level thread won't run, because the user-level threads scheduler doesn't know that one of its threads has been descheduled by the xv6 scheduler. As another example, two user-level threads will not run concurrently on different cores, because the xv6 scheduler isn't aware that there are multiple threads that could run in parallel. Note that if two user-level threads were to run truly in parallel, this implementation won't work because of several races (e.g., two threads on different processors could call <tt>thread_schedule</tt> concurrently, select the same runnable thread, and both run it on different processors.)
There are several ways of addressing these problems. One is using scheduler activations and another is to use one kernel thread per user-level thread (as Linux kernels do). Implement one of these ways in xv6. This is not easy to get right; for example, you will need to implement TLB shootdown when updating a page table for a multithreaded user process.
参考: https://courses.cs.duke.edu/fall23/compsci310/thread.html
为了实现上述功能,首先我们需要定义2个新的系统调用函数。
类似于进程里的fork 和 wait
一个是用来开一个线程,另一个是用来等待线程结束。
大概细节如下:
API Details
We describe the API here, including its input parameters, what it does, and its return value.
1 clone实现
主要是借鉴fork函数的框架 进行改写。
int
clone(void(*fcn)(void*, void*), void *arg1, void *arg2, void *stack)
{
int i, pid;
struct proc *np;
struct proc *p = myproc();
// Ensure stack is page align, which help setup guard page.
if(((uint64)stack % PGSIZE) != 0)
return -1;
// Allocate process.
if((np = allocproc()) == 0){
return -1;
}
// mark this process is thread
np->isthread = 1;
// use same page table as parent, to keep same memory space
np->pagetable = p->pagetable;
// share some variable between threads
np->tshared = p->tshared;
np->parent = p;
// copy saved user registers.
*(np->trapframe) = *(p->trapframe);
// setup thread's function address
np->trapframe->epc = (uint64)fcn;
// setup thread's function args
// refer to riscv calling covention: https://pdos.csail.mit.edu/6.828/2023/readings/riscv-calling.pdf
np->trapframe->a0 = (uint64)arg1;
np->trapframe->a1 = (uint64)arg2;
// ensure thread without exit return to a invalid address to trigger trap
np->trapframe->ra = 0xffffffffffffffff;
// Use the second page as the user stack.
np->trapframe->sp = (uint64)(stack + 2 * PGSIZE);
// Keep stack address for "join" to return
np->tstack = (uint64)stack;
// setup first stack page as guard page, remove PTE_U
uvmclear(np->pagetable, np->tstack);
// find a address to remap TRAPFRAME page
// it is important since TRAPFRAME page should not be shared across threads
uint64 trap_va = PHYSTOP;
for(; trap_va < TRAPFRAME ; trap_va += PGSIZE) {
if (kwalkaddr(np->pagetable, trap_va) == 0) {
np->trap_va = trap_va;
mappages(np->pagetable, np->trap_va, PGSIZE,
(uint64)(np->trapframe), PTE_R | PTE_W);
break;
}
}
// failed to find a space
if (trap_va >= TRAPFRAME) {
return -1;
}
// increment reference counts on open file descriptors.
for(i = 0; i < NOFILE; i++)
if(p->ofile[i])
np->ofile[i] = filedup(p->ofile[i]);
np->cwd = idup(p->cwd);
safestrcpy(np->name, p->name, sizeof(p->name));
pid = np->pid;
np->state = RUNNABLE;
release(&np->lock);
return pid;
}
2 改动proc.h
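然后我们需要对 proc 结构体做一些改写，新增后面代码会用到的字段：isthread（是否为线程）、tstack（线程栈地址，供 join 返回）、trap_va（线程 trapframe 的映射地址）以及 tshared（指向线程间共享的 sz 和锁 tlock）。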
3 支持动态TRAPFRAME
因为不同线程的 TRAPFRAME 不再是一个常数地址，所以我们需要在 proc 中记录这个地址（trap_va），并相应修改 trap 相关的汇编及其调用处，改动如下：
((void (*)(uint64,uint64))trampoline_userret)(p->trap_va, satp);
4 使得sz在thread 间共享
因为thread彼此之间是共享内存空间的,所以当有一个线程,增大了内存,应该对其他线程可见。
我们之前已经在 proc 结构体中通过指针的方式，把 sz 存放到 TRAPFRAME 页面的末尾，并在 clone 函数里让所有线程的这个指针都指向父进程 TRAPFRAME 页上该变量的地址。
因此修改 sz 时要通过指针修改，这样其他线程就能看到最新的 sz。同时，为了防止多个线程同时修改 sz 而出现更新丢失的情况，我们需要用同一把锁（tlock）对修改进行加锁。
更改相应的 growproc 函数：
// Grow or shrink the process's user memory by n bytes.
// sz lives in p->tshared so every thread sees the same value. The update is
// serialized by tshared->tlock so concurrent sbrk() calls don't lose updates.
// Returns 0 on success, -1 if allocation fails.
int
growproc(int n)
{
  struct proc *p = myproc();
  uint64 newsz;

  acquire(&p->tshared->tlock);
  newsz = p->tshared->sz;
  if(n < 0){
    newsz = uvmdealloc(p->pagetable, newsz, newsz + n);
  } else if(n > 0){
    newsz = uvmalloc(p->pagetable, newsz, newsz + n, PTE_W);
    if(newsz == 0){
      release(&p->tshared->tlock);
      return -1;
    }
  }
  p->tshared->sz = newsz;
  release(&p->tshared->tlock);
  return 0;
}
同时把其他所有用到 p->sz 的地方改为 p->tshared->sz。
5 支持 kwalkaddr
因为我们需要在（线程共享的）用户页表里找到一个尚未映射的页，把线程的 TRAPFRAME 动态映射上去。xv6 自带的 walkaddr 只返回带 PTE_U 的用户映射，而 trapframe 页是不带 PTE_U 的，所以借鉴 walkaddr，实现一个只检查 PTE_V 的版本 kwalkaddr：
// Look up va in pagetable and return the physical address it maps to, or 0
// if va is out of range or not mapped (PTE_V clear). Unlike walkaddr(), this
// does not require PTE_U. clone() uses it to find a free slot for a thread's
// kernel-only trapframe mapping.
uint64
kwalkaddr(pagetable_t pagetable, uint64 va)
{
  pte_t *pte;

  if(va >= MAXVA)
    return 0;
  if((pte = walk(pagetable, va, 0)) == 0 || (*pte & PTE_V) == 0)
    return 0;
  return PTE2PA(*pte);
}
6 第一个 thread 测试
这里借鉴了杜克大学写的测试。完成 clone 实现后，test1 应该可以跑通，但在 exit 时会报错，因为 exit 等相关代码还没有修改。
#include "kernel/types.h"
#include "user.h"
#undef NULL
#define NULL ((void*)0)
#define PGSIZE (4096)
// pid of the test's main process; a failing assert kills it.
int ppid;
// Shared state written by thread bodies and checked by the tests.
int global = 0;
int res1 = 0;
int res2 = 0;
// On failure: print file, line and the failed expression, kill ppid, then
// exit(1). Expands to an if/else statement.
#define assert(x) if (x) {} else { \
printf("%s: %d ", __FILE__, __LINE__); \
printf("assert failed (%s)\n", # x); \
printf("TEST FAILED\n"); \
kill(ppid); \
exit(1); \
}
// Thread body for test4: copies both int arguments into the globals res1 and
// res2, then ends the thread.
void
exittest(void *arg1, void * arg2){
  int first = *(int*)arg1, second = *(int*)arg2;

  res1 = first;
  res2 = second;
  exit(0);
}
// Thread body for test1: reads both int arguments and exits. It does no
// other work.
void
emptytest(void *arg1, void* arg2) {
  int sum = *(int*)arg1 + *(int*)arg2;

  (void)sum;
  exit(0);
}
// Thread body for test5: grows the shared heap by 64 KiB and writes the new
// region repeatedly.
void sbrktest(void* arg1, void* arg2) {
  enum { GROW = 65536 };
  char* buf = sbrk(GROW);

  for (int n = 0; n < 4096000; n++)
    buf[n % GROW] = 0;
  exit(0);
}
// Thread body for test6. When *arg1 == 1234 it starts one nested thread
// running this same function. That thread's *arg1 is 0, so it does not nest
// again. Both threads then busy-loop forever; they end only when the owning
// process exits and tpkill() kills them.
void threadinthread(void* arg1, void* arg2) {
int int1 = *(int*) arg1;
if (int1 == 1234) {
// create a new thread
int a1 = 0, a2 = 0;
int threadid = thread_create(threadinthread, &a1, &a2);
assert(threadid > ppid);
}
for (int i = 0; i < 4096000; i++) {
int1++;
}
while(1);
// unreachable: the loop above never terminates
exit(0);
}
// Thread body for test3: checks that both arguments arrive intact
// (*arg1 == 1, *arg2 == 2) and that arithmetic on locals works.
void
stacktest(void *arg1, void* arg2) {
int int1 = *(int*)arg1;
int int2 = *(int*)arg2;
assert(int1 == 1);
assert(int2 == 2);
int1 = int2 + int1;
assert(int1 == 3);
exit(0);
}
// Thread body for test3: checks that globals are shared with the creator.
// It expects global == 0 on entry and leaves it at 1; test3 then checks the
// creator sees the 1.
void
heaptest(void *arg1, void* arg2) {
int int1 = *(int*)arg1;
int int2 = *(int*)arg2;
assert(int1 == 1);
assert(int2 == 2);
assert(global == 0);
global++;
assert(global == 1);
exit(0);
}
//test1: thread create function
// Creates two emptytest threads and checks both get a pid larger than the
// main process's pid. The threads are deliberately not joined here; test2
// joins them.
int test1(){
uint64 arg1 = 1;
uint64 arg2 = 2;
int thread_pid1 = thread_create(emptytest, &arg1, &arg2);
int thread_pid2 = thread_create(emptytest, &arg1, &arg2);
assert(thread_pid1 > ppid);
assert(thread_pid2 > ppid);
printf("TEST1 PASSED\n");
return 0;
}
//test2: thread join function
// Joins the two threads test1 left behind; each join must return a pid.
int test2(){
int join_pid = thread_join();
assert(join_pid > 0);
join_pid = thread_join();
assert(join_pid > 0);
printf("TEST2 PASSED\n");
return 0;
}
//test3: shared address space
// Runs stacktest and heaptest. It then checks that the caller's arguments
// are unchanged and that heaptest's write to global (now 1) is visible here.
int test3(){
uint64 arg1 = 1;
uint64 arg2 = 2;
int thread_pid1 = thread_create(stacktest, &arg1, &arg2);
int thread_pid2 = thread_create(heaptest, &arg1, &arg2);
assert(thread_pid1 > 0);
assert(thread_pid2 > 0);
int join_pid = thread_join();
assert(join_pid > 0);
join_pid = thread_join();
assert(join_pid > 0);
assert(arg1 == 1);
assert(arg2 == 2);
assert(global == 1);
printf("TEST3 PASSED\n");
return 0;
}
//test4: wait/exit
// In a forked child, two exittest threads write res1/res2 and the child
// checks it sees those writes. The parent checks its own res1/res2 are still
// 0, because the child has a separate address space. The child's
// global == 1 is inherited from test3.
int test4(){
int pid = fork();
if(pid == 0){
ppid = getpid();
uint64 arg1 = 1;
uint64 arg2 = 2;
int thread_pid1 = thread_create(exittest, &arg1, &arg2);
int thread_pid2 = thread_create(exittest, &arg1, &arg2);
assert(thread_pid1 > 0);
assert(thread_pid2 > 0);
int join_pid = thread_join();
assert(join_pid > 0);
join_pid = thread_join();
assert(join_pid > 0);
assert(res1 == 1);
assert(res2 == 2);
assert(global == 1);
exit(0);
}
else{
int status;
wait(&status);
assert(status == 0);
assert(res1 == 0);
assert(res2 == 0);
printf("TEST4 PASSED\n");
return 0;
}
}
//test5: shared size
// Two threads call sbrk() concurrently and write the new memory. This
// exercises the shared, lock-protected sz.
int test5() {
int thread_pid1 = thread_create(sbrktest, 0, 0);
int thread_pid2 = thread_create(sbrktest, 0, 0);
assert(thread_pid1 > 0);
assert(thread_pid2 > 0);
thread_join();
thread_join();
printf("TEST5 PASSED\n");
return 0;
}
//test6: thread in thread
// In a forked child, a thread creates a nested thread and both spin forever.
// The child exits after sleep(20) without joining. Its exit must tear both
// threads down so the parent's wait() returns.
int test6() {
int pid = fork();
if (pid == 0) {
int arg1 = 1234;
int thread_pid1 = thread_create(threadinthread, &arg1, 0);
sleep(20);
assert(thread_pid1 > ppid);
exit(0);
} else {
wait(0);
printf("TEST6 PASSED\n");
}
return 0;
}
int
main(int argc, char *argv[])
{
ppid = getpid();
test1();
test2();
test3();
test4();
test5();
test6();
exit(0);
}
7 完成胶水函数
为了让编译通过，我们需要先把一些框架性的胶水代码补齐，暂未实现的函数（如 thread_join）可以先 return 0;
sysproc.c
// clone(fcn, arg1, arg2, stack) system call: all four arguments are user
// addresses, forwarded unchanged to clone().
uint64
sys_clone(void)
{
  uint64 args[4];

  for(int k = 0; k < 4; k++)
    argaddr(k, &args[k]);
  return clone((void *)args[0], (void *)args[1], (void *)args[2], (void *)args[3]);
}
// join(stack) system call: the single argument is the user address where
// join() stores the reaped thread's stack pointer.
uint64
sys_join(void)
{
  uint64 ustackp;

  argaddr(0, &ustackp);
  return join((void **)ustackp);
}
ulib.c
// First version of thread_create: over-allocates 3 pages and rounds up so a
// page-aligned 2-page stack fits. Because the rounded address is not what
// malloc returned, this stack can never be freed (see malloc_align).
// Returns the new thread's pid, or -1 on failure.
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  void *stack = malloc(3 * 4096);
  if(stack == 0)
    return -1;   // don't hand clone() a rounded-up NULL (address 0)
  uint64 addr = PGROUNDUP((uint64) stack);
  return clone(start_routine, arg1, arg2, (void *)addr);
}
// Placeholder so the library links before join() exists; always reports 0.
int thread_join(){
  const int placeholder_pid = 0;
  return placeholder_pid;
}
测试效果:
8 修正exit
根据之前课程我们可以知道exit
负责设置退出状态,之后会由父进程再wait的时候,去释放资源。
上面的错误是因为任何一个线程退出时都会释放整个内存空间（页表），而此时其他线程还在运行，页表就会错乱。所以我们需要保证线程退出时不释放这些共享资源。
我们的实现方案是：只有进程本身退出时才释放资源，基于该进程创建出来的线程退出时不释放。
如果进程最先退出,他会KILL掉其他所有线程,然后等待他们完成,自己再退出。
这块代码可以借鉴 freeproc 或 kill。我们来看下 exit 的实现改动：
8.1 线程exit
首先，reparent 这一步只需要处理子进程：线程的子线程不需要交给 init 进程去 wait 和释放资源。
// Pass p's abandoned child processes to init.
// Child threads are skipped; they are not handed to init.
// Caller must hold wait_lock.
void
reparent(struct proc *p)
{
  struct proc *child;

  for(child = proc; child < &proc[NPROC]; child++){
    if(child->parent != p || child->isthread)
      continue;
    child->parent = initproc;
    wakeup(initproc);
  }
}
然后在 exit 里，解除 clone 时为线程 trapframe 建立的 trap_va 映射：
p->xstate = status;
p->state = ZOMBIE;
// A thread's trapframe was mapped at trap_va by clone(). Remove that mapping
// from the shared page table. do_free=0 because the trapframe page itself is
// kfree'd later by freeproc().
// NOTE(review): the shared page table is modified here without tshared->tlock.
if (p->isthread) {
uvmunmap(p->pagetable, p->trap_va, 1, 0);
}
wait这边,只考虑进程,而非线程
// Wait for a child process to exit and return its pid.
// Return -1 if this process has no children.
int
wait(uint64 addr)
{
struct proc *pp;
int havekids, pid;
struct proc *p = myproc();
acquire(&wait_lock);
for(;;){
// Scan through table looking for exited children.
havekids = 0;
for(pp = proc; pp < &proc[NPROC]; pp++){
// wait only consider process
if(pp->parent == p && !pp->isthread){
freeproc：线程不需要释放 pagetable（页表属于进程，由进程释放）
// Free a proc structure and the data hanging from it, including user pages.
// p->lock must be held.
//
// Only a real process owns its page table; a thread borrows its parent's, so
// a thread's page table is never torn down here. For a process, the page
// table is freed before the trapframe because p->tshared (holding sz) lives
// in the trapframe page.
static void
freeproc(struct proc *p)
{
  if(!p->isthread && p->pagetable != 0)
    proc_freepagetable(p->pagetable, p->tshared->sz);
  if(p->trapframe != 0)
    kfree((void*)p->trapframe);

  p->pagetable = 0;
  p->trapframe = 0;
  p->tshared = 0;
  p->parent = 0;
  p->pid = 0;
  p->name[0] = 0;
  p->chan = 0;
  p->killed = 0;
  p->xstate = 0;
  p->tstack = 0;
  p->trap_va = 0;
  p->isthread = 0;
  p->state = UNUSED;
}
8.2 进程exit
// Kill every thread whose parent is curproc, then wait for each one to
// become a ZOMBIE and free it. exit() calls this for processes and threads
// alike, so nested threads are torn down recursively.
void
tpkill(struct proc *curproc)
{
  struct proc *p;
  int havethreads;

  acquire(&wait_lock);

  // Mark all child threads killed and wake sleepers so they notice.
  for(p = proc; p < &proc[NPROC]; p++){
    if(p->parent == curproc && p->isthread){
      acquire(&p->lock);
      p->killed = 1;
      if(p->state == SLEEPING)
        p->state = RUNNABLE;
      release(&p->lock);
    }
  }

  // Reap them. wait_lock (held) protects ->parent; ->state must be read
  // under p->lock. Every zombie found in a pass is freed, instead of stopping
  // at the first live thread.
  // NOTE(review): assumes exit() calls wakeup(parent) while holding
  // wait_lock, as stock xv6 does, so the sleep below cannot miss it.
  for(;;){
    havethreads = 0;
    for(p = proc; p < &proc[NPROC]; p++){
      if(p->parent != curproc || !p->isthread)
        continue;
      acquire(&p->lock);
      if(p->state == ZOMBIE)
        freeproc(p);
      else
        havethreads = 1;
      release(&p->lock);
    }

    if(!havethreads)
      break;

    // sleep until one of the remaining threads exits
    sleep(curproc, &wait_lock);
  }

  release(&wait_lock);
}
无论是进程还是线程，exit 时都要调用 tpkill，杀掉并回收它创建的线程（线程也可能再创建线程）。
// Exit the current process. Does not return.
// An exited process remains in the zombie state
// until its parent calls wait(). Threads (isthread) are reaped by join()
// instead, since wait() skips them.
void
exit(int status)
{
struct proc *p = myproc();
if(p == initproc)
panic("init exiting");
// Close all open files.
for(int fd = 0; fd < NOFILE; fd++){
if(p->ofile[fd]){
struct file *f = p->ofile[fd];
fileclose(f);
p->ofile[fd] = 0;
}
}
begin_op();
iput(p->cwd);
end_op();
p->cwd = 0;
// Kill and reap every thread this proc created (called for processes and
// threads alike). The rest of exit() is elided in this excerpt.
tpkill(p);
...
}
test1 通过
9 实现 join
基本照抄wait函数,针对isthread 做一些修改
// Wait for a child thread (created by clone) to exit.
// If stack is non-null, the thread's stack address (as passed to clone) is
// copied out to *stack so the caller can free it.
// Returns the thread's pid. Returns -1 if the caller has no child threads,
// has been killed, or the copyout fails; a failed copyout leaves the zombie
// in place to be joined later.
int
join(void **stack)
{
  struct proc *pp;
  int havekids, pid;
  struct proc *p = myproc();

  acquire(&wait_lock);

  for(;;){
    // Scan through table looking for exited child threads.
    havekids = 0;
    for(pp = proc; pp < &proc[NPROC]; pp++){
      // join only reaps threads; wait() handles forked children.
      if(pp->parent != p || !pp->isthread)
        continue;
      acquire(&pp->lock);
      havekids = 1;
      if(pp->state == ZOMBIE){
        pid = pp->pid;
        if(stack != 0 && copyout(p->pagetable, (uint64)stack, (char *)&pp->tstack,
                                 sizeof(pp->tstack)) < 0) {
          release(&pp->lock);
          release(&wait_lock);
          return -1;
        }
        // Restore PTE_U on the guard page so the memory is usable again.
        // The page table is shared, so hold tlock: growproc() may otherwise
        // be unmapping this PTE concurrently and the |= would resurrect it.
        acquire(&p->tshared->tlock);
        uvmset(pp->pagetable, pp->tstack);
        release(&p->tshared->tlock);
        freeproc(pp);
        release(&pp->lock);
        release(&wait_lock);
        return pid;
      }
      release(&pp->lock);
    }

    // No point waiting if we don't have any child threads.
    if(!havekids || p->killed){
      release(&wait_lock);
      return -1;
    }

    // Wait for a child thread to exit.
    sleep(p, &wait_lock);
  }
}
我们之前在 clone时设置的guard page,因为之后要还给用户态的内存使用,所以需要把PTE_U重新设置上
// Set PTE_U on the leaf PTE for va; this is the inverse of uvmclear().
// join() uses it to make a thread's guard page accessible to user code
// again. Panics if no page-table entry exists for va.
void
uvmset(pagetable_t pagetable, uint64 va)
{
  pte_t *pte;

  pte = walk(pagetable, va, 0);
  if(pte == 0)
    panic("uvmset");   // previously mislabelled "uvmclear"
  *pte |= PTE_U;
}
// Wait for a child thread to exit and free the stack thread_create gave it.
// Returns the thread's pid, or -1 if join() fails.
int thread_join(){
  // join() only writes *stack on success. xv6's free() cannot take NULL or
  // garbage, so start from 0 and free only what join() reported.
  void *stack = 0;
  int pid = join(&stack);
  if(pid > 0 && stack != 0)
    free(stack);
  return pid;
}
不过还存在一个问题：free 需要拿到的是 malloc 返回的起始地址，但我们在 malloc 之后又做了一次 PGROUNDUP，因此 free 没法真正释放之前 malloc 的内存。
10 新增malloc_align
为了解决上述问题，我们需要修改 umalloc.c，增加一个 malloc_align 函数。它会从空闲链表中找到一个 4096 对齐的地址空间并返回。
// Allocate oribytes bytes whose returned address is 4096-byte aligned.
// The result is freeable with free() because a normal block header sits just
// below the returned address. Returns 0 if out of memory.
//
// A free block p is split into up to three pieces, in address order:
//   [p: left remainder][ret: returned block][right: tail remainder]
void*
malloc_align(uint oribytes)
{
  Header *p, *prevp, *ret, *right;
  uint nunits, ounits;

  // Search for a block one page larger than needed so an aligned start is
  // guaranteed to fit inside it.
  uint nbytes = oribytes + 4096;
  ounits = (oribytes + sizeof(Header) - 1)/sizeof(Header) + 1;
  nunits = (nbytes + sizeof(Header) - 1)/sizeof(Header) + 1;
  if((prevp = freep) == 0){
    base.s.ptr = freep = prevp = &base;
    base.s.size = 0;
  }
  for(p = prevp->s.ptr; ; prevp = p, p = p->s.ptr){
    if(p->s.size >= nunits){
      uint64 paddr = (uint64) p;
      // Put ret's header just below the first page boundary after p's own
      // header, so (ret + 1) is page aligned.
      uint64 align_addr = PGROUNDUP(paddr + sizeof(Header)) - sizeof(Header);
      uint sz = (align_addr - paddr)/sizeof(Header), psz = p->s.size;
      Header *next = p->s.ptr;   // read before any header is overwritten
      ret = (Header *)align_addr;
      ret->s.size = ounits;
      right = ret + ounits;
      right->s.size = psz - ounits - sz;
      right->s.ptr = next;
      if (sz == 0) {
        // ret is p itself: replace p with right in the free list.
        prevp->s.ptr = right;
      } else {
        // Keep p as the lower left remainder and link right after it. This
        // preserves the ascending-address order that free() relies on.
        p->s.size = sz;
        p->s.ptr = right;
      }
      freep = prevp;
      return (void*)(ret + 1);
    }
    if(p == freep)
      if((p = morecore(nunits)) == 0)
        return 0;
  }
}
然后在thread_create 使用malloc_align 去分配对齐的内存。
// Start a thread running start_routine(arg1, arg2) on a fresh page-aligned
// 2-page stack from malloc_align. Returns the thread's pid, or -1 on failure.
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  void *stack = malloc_align(8192);
  if(stack == 0)
    return -1;
  int pid = clone(start_routine, arg1, arg2, stack);
  if(pid < 0)
    free(stack);   // clone failed: the stack would otherwise leak
  return pid;
}
duke 的6个测试全部通过
11 用户级别的锁
我们之前用的 spinlock 是内核层面的。现在用户态也可以进行多线程编程了，所以需要支持用户级别的锁来保证线程安全。比如在调用 malloc_align 时，如果 2 个线程同时操作空闲链表必然会出问题，所以需要用锁来保护。
首先我们实现一个原子的 fetch-and-add（读取旧值并加上给定值）。
// Atomically add value to *variable and return the previous value.
// Sequentially consistent. On RISC-V this typically compiles to
// amoadd.w.aqrl. The earlier bare amoadd.w carried no acquire/release bits,
// so it did not order the surrounding memory accesses under RVWMO.
static inline int fetch_and_add(int* variable, int value) {
  return __atomic_fetch_add(variable, value, __ATOMIC_SEQ_CST);
}
锁里有 2 个变量：ticket 是下一个要发放的号码，turn 是当前允许持有锁的号码。
/* Ticket lock state. */
struct _lock_t {
  int ticket;   /* next ticket number to hand out */
  int turn;     /* ticket number currently allowed to hold the lock */
};
typedef struct _lock_t lock_t;
比如第一个线程上锁,拿到ticket = 0, turn = 0;
第二个线程尝试上同一把锁，拿到 ticket = 1，而此时 turn = 0，于是开始 spin，等待 turn 变为 1；
第一个线程释放锁, turn = turn + 1;
第二线程spin等待退出。
// Put a ticket lock in the unlocked state: no tickets issued, none served.
void lock_init(lock_t *lock) {
  lock->ticket = lock->turn = 0;
}
// Acquire a ticket lock: take the next ticket, then spin until turn reaches
// it. The acquire load stops critical-section accesses from being performed
// before the lock is held; RISC-V memory ordering is weak, and the relaxed
// amoadd.w used before gave no such guarantee.
void lock_acquire(lock_t *lock) {
  int myturn = __atomic_fetch_add(&lock->ticket, 1, __ATOMIC_SEQ_CST);
  while (__atomic_load_n(&lock->turn, __ATOMIC_ACQUIRE) != myturn) {
    ;
  }
}
// Release a ticket lock by admitting the next ticket holder. The atomic
// release increment publishes the critical section's writes before turn
// advances. The plain turn = turn + 1 gave no ordering and raced with
// waiters' atomic accesses.
void lock_release(lock_t *lock) {
  __atomic_fetch_add(&lock->turn, 1, __ATOMIC_RELEASE);
}
然后对 thread_create 中的 malloc_align 调用进行上锁保护：
// Start a thread running start_routine(arg1, arg2) on a fresh page-aligned
// 2-page stack. The user-level allocator is not thread safe, so allocator
// calls are serialized with thread_create_lock.
// Returns the thread's pid, or -1 on failure.
int thread_create(void (*start_routine)(void *, void *), void *arg1, void *arg2)
{
  lock_acquire(&thread_create_lock);
  void *stack = malloc_align(8192);
  lock_release(&thread_create_lock);
  if(stack == 0)
    return -1;

  int pid = clone(start_routine, arg1, arg2, stack);
  if(pid < 0){
    // clone failed: give the stack back instead of leaking it.
    lock_acquire(&thread_create_lock);
    free(stack);
    lock_release(&thread_create_lock);
  }
  return pid;
}
12 更多的测试
clone, join 测试9个
// Atomically store newval into *addr and return the previous value.
// Uses a sequentially consistent exchange (typically amoswap.d.aqrl on
// RISC-V). The hand-written lr.d/sc.d loop had no acquire/release bits, so it
// did not order surrounding accesses.
static inline uint64
xchg(volatile uint64 *addr, uint64 newval) {
  return __atomic_exchange_n(addr, newval, __ATOMIC_SEQ_CST);
}
#include "kernel/types.h"
#include "user/user.h"
#include "kernel/fcntl.h"
#include "kernel/riscv.h"
#undef NULL
#define NULL ((void*)0)
// pid of the test process; a failing assert kills it.
int ppid;
// Shared with clone()d threads. volatile because the main thread spins on
// values the threads write.
volatile int arg1 = 11;
volatile int arg2 = 22;
volatile int global = 1;
volatile uint64 newfd = 0;
// On failure: print file, line and the failed expression, kill ppid, then
// exit(0). Expands to an if/else statement.
#define assert(x) if (x) {} else { \
printf("%s: %d ", __FILE__, __LINE__); \
printf("assert failed (%s)\n", # x); \
printf("TEST FAILED\n"); \
kill(ppid); \
exit(0); \
}
// Thread bodies passed to clone(); defined after main().
void worker(void *arg1, void *arg2);
void worker2(void *arg1, void *arg2);
void worker3(void *arg1, void *arg2);
void worker4(void *arg1, void *arg2);
void worker5(void *arg1, void *arg2);
void worker6(void *arg1, void *arg2);
/* clone and verify that address space is shared */
// worker sets global to 5; the main thread spins until it sees that write.
void test1(void *stack)
{
int clone_pid = clone(worker, 0, 0, stack);
assert(clone_pid > 0);
while(global != 5);
printf("TEST1 PASSED\n");
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
// reset for the next test
global = 1;
}
/* clone and play with the argument */
// worker2 reads the two globals through its arguments (11 + 22 = 33) and
// overwrites them with 44 and 55; the main thread checks both effects.
void test2(void *stack)
{
int clone_pid = clone(worker2, (void*)&arg1, (void*)&arg2, stack);
assert(clone_pid > 0);
while(global != 33);
assert(arg1 == 44);
assert(arg2 == 55);
printf("TEST2 PASSED\n");
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
}
/* clone copies file descriptors, but doesn't share */
// The thread can write to fd 3, which it inherited. A file it opens itself
// gets a descriptor number that is not valid in the main thread's own table,
// so writing to it from here fails.
void test3(void *stack)
{
int fd = open("tmp", O_WRONLY|O_CREATE);
assert(fd == 3);
int clone_pid = clone(worker3, 0, 0, stack);
assert(clone_pid > 0);
while(!newfd);
assert(write(newfd, "goodbye\n", 8) == -1);
printf("TEST3 PASSED\n");
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
}
/* clone with bad stack argument */
// stack+4 is not page aligned, so clone() must reject it.
void test4(void *stack)
{
assert(clone(worker4, 0, 0, stack+4) == -1);
printf("TEST4 PASSED\n");
}
/* clone and join syscalls */
// join() must return the thread's pid and report the same stack address that
// was passed to clone().
void test5(void *stack)
{
global = 1;
int arg1 = 42, arg2 = 24;
int clone_pid = clone(worker5, &arg1, &arg2, stack);
assert(clone_pid > 0);
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
assert(stack == join_stack);
assert(global == 2);
printf("TEST5 PASSED\n");
}
/* join argument checking */
// join_stack is the last 8 bytes of the heap. join_stack + 4 makes the
// 8-byte copyout straddle the end of mapped memory, so join() must fail
// there and succeed with the properly placed pointer.
void test6(void *stack)
{
global = 1;
int arg1 = 42, arg2 = 24;
int clone_pid = clone(worker5, &arg1, &arg2, stack);
assert(clone_pid > 0);
sbrk(PGSIZE);
void **join_stack = (void**) ((uint64)sbrk(0) - 8);
assert(join((void**)((uint64)join_stack + 4)) == -1);
assert(join(join_stack) == clone_pid);
assert(stack == *join_stack);
assert(global == 2);
printf("TEST6 PASSED\n");
}
/* join should not handle child processes (forked) */
// With only a forked child and no threads, join() returns -1 and the child
// is still reaped by wait().
void test7(void *stack)
{
global = 1;
int fork_pid = fork();
if(fork_pid == 0) {
exit(0);
}
assert(fork_pid > 0);
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == -1);
assert(wait(0) > 0);
printf("TEST7 PASSED\n");
}
/* join, not wait, should handle threads */
// With only a thread child, wait() returns -1 and join() reaps the thread.
void test8(void *stack)
{
global = 1;
int arg1 = 42, arg2 = 24;
int clone_pid = clone(worker5, &arg1, &arg2, stack);
assert(clone_pid > 0);
sleep(10);
assert(wait(0) == -1);
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
assert(stack == join_stack);
assert(global == 2);
printf("TEST8 PASSED\n");
}
/* set up stack correctly (and without extra items) */
// worker6 receives the stack base as arg1 and checks the frame layout at the
// top of the 2-page stack.
void test9(void *stack)
{
global = 1;
int clone_pid = clone(worker6, stack, 0, stack);
assert(clone_pid > 0);
while(global != 5);
printf("TEST9 PASSED\n");
void *join_stack;
int join_pid = join(&join_stack);
assert(join_pid == clone_pid);
}
// The clone/join tests, run in order.
void (*functions[])() = {test1, test2, test3, test4, test5, test6, test7, test8, test9};
int
main(int argc, char *argv[])
{
  int len = sizeof(functions) / sizeof(functions[0]);
  for(int i = 0; i < len; i++) {
    ppid = getpid();
    // clone() uses a page-aligned stack of 2 pages (guard + stack). Allocate
    // 3 pages so that rounding p up to a page boundary still leaves 2 whole
    // pages inside the allocation. With only 2 pages, an unaligned p made
    // the thread's stack run past the end of the malloc'd block.
    void *stack, *p = malloc(PGSIZE * 3);
    assert(p != NULL);
    stack = ((uint64)p % PGSIZE) ? (p + (PGSIZE - (uint64)p % PGSIZE)) : p;
    (*functions[i])(stack);
    free(p);
  }
  exit(0);
}
// Thread body for test1: expects global == 1, then sets it to 5.
void
worker(void *arg1, void *arg2) {
assert(global == 1);
global = 5;
exit(0);
}
// Thread body for test2: reads both int arguments, overwrites them with 44
// and 55, and stores their original sum in global.
void
worker2(void *arg1, void *arg2) {
int tmp1 = *(int*)arg1;
int tmp2 = *(int*)arg2;
*(int*)arg1 = 44;
*(int*)arg2 = 55;
assert(global == 1);
global = tmp1 + tmp2;
exit(0);
}
// Thread body for test3: writes through the inherited fd 3, then opens a new
// file and publishes its descriptor number in newfd.
void
worker3(void *arg1, void *arg2) {
assert(write(3, "hello\n", 6) == 6);
xchg(&newfd, open("tmp2", O_WRONLY|O_CREATE));
exit(0);
}
// Thread body for test4. It should never run, because test4 expects clone()
// to reject the misaligned stack.
void
worker4(void *arg1, void *arg2) {
  (void)arg1;
  (void)arg2;
  exit(0);
}
// Thread body for test5/6/8: checks the arguments (42, 24) and that
// global == 1, then increments global to 2.
void
worker5(void *arg1, void *arg2) {
int tmp1 = *(int*)arg1;
int tmp2 = *(int*)arg2;
assert(tmp1 == 42);
assert(tmp2 == 24);
assert(global == 1);
global++;
exit(0);
}
// Thread body for test9. arg1 is the stack base passed to clone(), so the
// stack top is arg1 + 2 * PGSIZE. It checks the saved return address (the
// invalid 0xffff... set by clone) and where the parameters were spilled.
// NOTE(review): these offsets depend on the compiler's frame layout
// (frame pointer enabled); the C language does not guarantee them.
void
worker6(void *arg1, void *arg2) {
// arg1 -> top stack
// arg1 -8 -> ra
// arg1 -16 -> fp
// arg1 - 24 -> a0
// arg1 - 32 -> a1
assert(*((uint64*) (arg1 + 2 * PGSIZE - 8)) == 0xffffffffffffffff);
assert((uint64)&arg2 == ((uint64)arg1 + 2 * PGSIZE - 32));
assert((uint64)&arg1 == ((uint64)arg1 + 2 * PGSIZE - 24));
global = 5;
exit(0);
}
// Thread body that checks its arguments are 35 and 42 and increments global.
// Not referenced by any test in this file.
void
worker7(void *arg1, void *arg2) {
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
assert(global == 1);
global++;
exit(0);
}
thread_create, thread_join 测试13个
#include "kernel/types.h"
#include "user.h"
#include "kernel/fcntl.h"
#include "kernel/riscv.h"
#undef NULL
#define NULL ((void*)0)
// pid of the test process; a failing assert kills it.
int ppid;
// Shared state used by the tests and their thread bodies.
int global = 1;
uint64 size = 0;
lock_t lock, lock2;
int num_threads = 30;
int loops = 10;
int* global_arr;
// On failure: print file, line and the failed expression, kill ppid, then
// exit(0). Expands to an if/else statement, so a trailing semicolon is
// optional at use sites.
#define assert(x) if (x) {} else { \
printf("%s: %d ", __FILE__, __LINE__); \
printf("assert failed (%s)\n", # x); \
printf("TEST FAILED\n"); \
kill(ppid); \
exit(0); \
}
// Thread bodies passed to thread_create(); defined after main().
void worker(void *arg1, void *arg2);
void worker2(void *arg1, void *arg2);
void worker3(void *arg1, void *arg2);
void worker4(void *arg1, void *arg2);
void worker5(void *arg1, void *arg2);
void worker6(void *arg1, void *arg2);
void worker7(void *arg1, void *arg2);
void merge_sort(void *array, void *size);
void worker9(void *array, void *size);
void worker10(void *array, void *size);
void worker11(void *array, void *size);
void worker12(void *array, void *size);
void worker13(void *array, void *size);
/* thread user library functions */
// One thread_create/thread_join round trip; worker bumps global from 1 to 2.
void test1()
{
int arg1 = 35;
int arg2 = 42;
int thread_pid = thread_create(worker, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global == 2);
printf("TEST1 PASSED\n");
}
/* memory leaks from thread library? */
// 2000 create/join cycles. If thread stacks were not freed, the heap would
// pass the 150-page bound.
void test2()
{
int i, thread_pid, join_pid;
for(i = 0; i < 2000; i++) {
global = 1;
thread_pid = thread_create(worker2, 0, 0);
assert(thread_pid > 0);
join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global == 5);
assert((uint64)sbrk(0) < (150 * 4096) && "shouldn't even come close");
}
printf("TEST2 PASSED\n");
}
/* check that address space size is updated in threads */
// Handshake: main holds lock and lock2 and records sbrk(0) in size. It hands
// out lock until every worker3 has checked size, then grows the heap,
// records the new size, and repeats with lock2. Each worker must see the
// current size both times.
void test3()
{
global = 0;
int arg1 = 11, arg2 = 22;
lock_init(&lock);
lock_init(&lock2);
lock_acquire(&lock);
lock_acquire(&lock2);
for (int i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker3, &arg1, &arg2);
assert(thread_pid > 0);
}
size = (uint64)sbrk(0);
while (global < num_threads) {
lock_release(&lock);
sleep(2);
lock_acquire(&lock);
}
global = 0;
sbrk(10000);
size = (uint64)sbrk(0);
lock_release(&lock);
while (global < num_threads) {
lock_release(&lock2);
sleep(2);
lock_acquire(&lock2);
}
lock_release(&lock2);
for (int i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
printf("TEST3 PASSED\n");
}
/* multiple threads with some depth of function calls */
// Naive doubly recursive Fibonacci, kept deliberately slow so threads do
// real work with some call depth.
uint fib(uint n) {
  if (n <= 1)
    return n;
  return fib(n - 1) + fib(n - 2);
}
// Multiple threads with some depth of function calls: 30 threads each run
// several fib() computations.
void test4()
{
assert(fib(28) == 317811);
int arg1 = 11, arg2 = 22;
for (int i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker4, &arg1, &arg2);
assert(thread_pid > 0);
}
for (int i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
printf("TEST4 PASSED\n");
}
/* no exit call in thread, should trap at bogus address */
// worker5 returns instead of calling exit(); the thread should still be
// joinable, and its increment of global should be visible.
void test5()
{
int arg1 = 42, arg2 = 24;
int thread_pid = thread_create(worker5, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global == 2);
printf("TEST5 PASSED\n");
}
/* test lock correctness */
// num_threads workers each do loops locked read-delay-write increments. Any
// lost update makes the final count come out short.
void test6()
{
global = 0;
lock_init(&lock);
int i;
for (i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker6, 0, 0);
assert(thread_pid > 0);
}
for (i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
assert(global == num_threads * loops);
printf("TEST6 PASSED\n");
}
/* nested thread user library functions */
// worker7 bumps global to 2 and joins a nested thread that bumps it to 3.
void test7()
{
int arg1 = 35;
int arg2 = 42;
int thread_pid = thread_create(worker7, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global == 3);
printf("TEST7 PASSED\n");
}
/* merge sort using nested threads */
// Sorts the descending array 10..0 with merge_sort, which spawns a thread
// per half, and spot-checks the result.
void test8()
{
/*
1. Create global array and populate it
2. invoke merge sort (array ptr, size)
Merge sort:
0. base case - size = 1 --> return
1. thread create with merge sort (array left, size/2)
2. thread create with merge sort (array + size/2, size - size/2)
3. join both threads
4. Merge function
*/
int size = 11;
global_arr = (int*)malloc(size * sizeof(int));
for(int i = 0; i < size; i++){
global_arr[i] = size - i - 1;
}
int thread_pid = thread_create(merge_sort, global_arr, &size);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global_arr[0] == 0);
assert(global_arr[5] == 5);
assert(global_arr[10] == 10);
printf("TEST8 PASSED\n");
}
/* test lock correctness using nested threads */
// Each of num_threads workers increments global once under lock and joins a
// nested thread that does the same, so the total must be 2 * num_threads.
void test9()
{
global = 0;
lock_init(&lock);
int i;
for (i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker9, 0, 0);
assert(thread_pid > 0);
}
for (i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
assert(global == num_threads * 2);
printf("TEST9 PASSED\n");
}
/* no exit call in nested thread, should trap at bogus address */
// worker10 bumps global to 2. Its nested thread bumps it to 3 and then
// returns without exit().
void test10()
{
int arg1 = 42, arg2 = 24;
int thread_pid = thread_create(worker10, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid == thread_pid);
assert(global == 3);
printf("TEST10 PASSED\n");
}
/* check that address space size is updated in threads */
// Same handshake as test3, but driven from inside a thread (worker11)
// against one nested thread.
void test11()
{
int arg1 = 11, arg2 = 22;
size = (uint64)sbrk(0);
int thread_pid = thread_create(worker11, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid > 0);
printf("TEST11 PASSED\n");
}
/* check that thread stack overflow, should trap */
// worker12 recurses forever. The overflow should hit the guard page and
// kill the thread, which must still be joinable; call_forever has already
// set global to 3.
void test12()
{
int arg1 = 11, arg2 = 22;
size = (uint64)sbrk(0);
int thread_pid = thread_create(worker12, &arg1, &arg2);
assert(thread_pid > 0);
int join_pid = thread_join();
assert(join_pid > 0);
assert(global == 3);
printf("TEST12 PASSED\n");
}
/* check no malloc stack race condition */
// 30 threads each create and join a nested thread at roughly the same time.
// If concurrent stack allocation were racy, the heap would grow past the
// bound (or corrupt).
void test13()
{
num_threads = 30;
int i;
int arg1 = 35;
int arg2 = 42;
uint64 origin = (uint64)sbrk(0);
for (i = 0; i < num_threads; i++) {
int thread_pid = thread_create(worker13, &arg1, &arg2);
assert(thread_pid > 0);
}
for (i = 0; i < num_threads; i++) {
int join_pid = thread_join();
assert(join_pid > 0);
}
assert((uint64)sbrk(0) < (origin + (16 + num_threads * 2 * 3) * 4096) && "shouldn't even come close");
printf("TEST13 PASSED\n");
}
void (*functions[])() = {test13};
int
main(int argc, char *argv[])
{
int len = sizeof(functions) / sizeof(functions[0]);
for(int i = 0; i < len; i++) {
global = 1;
ppid = getpid();
(*functions[i])();
}
exit(0);
}
// Thread body for test1: checks its arguments (35, 42) and that
// global == 1, then increments global.
void
worker(void *arg1, void *arg2) {
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
assert(global == 1);
global++;
exit(0);
}
// Thread body for test2: expects global == 1 and adds 4.
void
worker2(void *arg1, void *arg2) {
assert(global == 1);
global += 4;
exit(0);
}
// Thread body for test3. In each of the two handshake phases it takes the
// phase's lock, checks that sbrk(0) matches the size main published, and
// counts itself in global.
void
worker3(void *arg1, void *arg2) {
lock_acquire(&lock);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock);
lock_acquire(&lock2);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock2);
exit(0);
}
// Thread body for test4: checks its arguments (11, 22), then runs several
// fib() calls to exercise call depth on the thread stack.
void
worker4(void *arg1, void *arg2) {
int tmp1 = *(int*)arg1;
int tmp2 = *(int*)arg2;
assert(tmp1 == 11);
assert(tmp2 == 22);
assert(global == 1);
assert(fib(2) == 1);
assert(fib(3) == 2);
assert(fib(9) == 34);
assert(fib(15) == 610);
exit(0);
}
// Thread body for test5: increments global and then returns without
// exit(). clone() set ra to an invalid address, so the return is expected
// to trap.
void
worker5(void *arg1, void *arg2) {
int tmp1 = *(int*)arg1;
int tmp2 = *(int*)arg2;
assert(tmp1 == 42);
assert(tmp2 == 24);
assert(global == 1);
global++;
// no exit() in thread
}
// Thread body for test6: increments global loops times under lock. It uses a
// deliberately slow read, delay, write sequence so a broken lock would lose
// updates.
void
worker6(void *arg1, void *arg2) {
  for (int iter = 0; iter < loops; iter++) {
    int snapshot;

    lock_acquire(&lock);
    snapshot = global;
    for (int spin = 0; spin < 50; spin++)
      ;  // take some time
    global = snapshot + 1;
    lock_release(&lock);
  }
  exit(0);
}
// Nested thread body for test7: checks its arguments and that worker7 has
// already bumped global to 2, then bumps it to 3.
void nested_worker(void *arg1, void *arg2){
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
assert(global == 2);
global++;
exit(0);
}
// Thread body for test7: bumps global to 2, then creates and joins
// nested_worker, passing pointers to its own stack copies of the arguments.
void
worker7(void *arg1, void *arg2) {
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
assert(global == 1);
global++;
int nested_thread_pid = thread_create(nested_worker, &arg1_int, &arg2_int);
int nested_join_pid = thread_join();
assert(nested_join_pid == nested_thread_pid);
exit(0);
}
// Merge two adjacent sorted runs, array[0..size_left) and
// array_right[0..size_right), through temp_array, then copy the result back
// over array[0 .. size_left + size_right). On equal elements the right run's
// element is taken first.
void merge(int* array, int* array_right,int size_left, int size_right,int*temp_array){
  int *out = temp_array;
  int *l = array, *l_end = array + size_left;
  int *r = array_right, *r_end = array_right + size_right;

  while (l < l_end && r < r_end)
    *out++ = (*l < *r) ? *l++ : *r++;
  while (l < l_end)
    *out++ = *l++;
  while (r < r_end)
    *out++ = *r++;

  for (int k = 0; k < size_left + size_right; k++)
    array[k] = temp_array[k];
}
// Thread body for test8: sorts *arg2 ints starting at arg1. It sorts each
// half in its own nested thread, joins both, then merges.
// NOTE(review): only size == 1 is a base case; size 0 would keep spawning
// threads. test8 always passes size >= 1.
void merge_sort(void *arg1, void *arg2) {
int *array = (int*)arg1;
int size = *(int*)arg2;
if (size==1){
exit(0);
}
int size_left = size/2;
int size_right = size-size/2;
int* array_right = (int*)(array + size_left);
int nested_thread_pid_l = thread_create(merge_sort, array, &size_left);
int nested_thread_pid_r = thread_create(merge_sort, array_right, &size_right);
int nested_join_pid_1 = thread_join();
int nested_join_pid_2 = thread_join();
int* temp_array = malloc(size*sizeof(int));
merge(array,array_right,size_left,size_right,temp_array);
free(temp_array);
// The two joins may complete in either order.
assert(nested_thread_pid_l == nested_join_pid_1 || nested_thread_pid_l == nested_join_pid_2);
assert(nested_thread_pid_r == nested_join_pid_1 || nested_thread_pid_r == nested_join_pid_2);
exit(0);
}
// Nested thread body for test9: increments global once under lock, with a
// short delay inside the critical section.
void nest_worker(void *arg1,void *arg2){
  lock_acquire(&lock);
  for (int spin = 0; spin < 50; spin++)
    ;  // take some time
  global++;
  lock_release(&lock);
  exit(0);
}
// Thread body for test9: increments global once under lock, then creates
// and joins a nest_worker that does the same.
void
worker9(void *arg1, void *arg2) {
lock_acquire(&lock);
int j;
for(j = 0; j < 50; j++); // take some time
global++;
lock_release(&lock);
int nested_thread_pid = thread_create(nest_worker, 0, 0);
assert(nested_thread_pid > 0);
int nested_join_pid = thread_join();
assert(nested_join_pid > 0);
assert(nested_thread_pid==nested_join_pid);
exit(0);
}
// Nested thread body for test10: bumps global from 2 to 3, then returns
// without exit(), so the return is expected to trap.
void nested_worker2(void *arg1, void *arg2){
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 42);
assert(arg2_int == 24);
assert(global == 2);
global++;
// no exit() in thread
}
// Thread body for test10: checks its arguments and bumps global to 2. It
// then starts nested_worker2, which bumps global to 3 and returns without
// exit(), and joins it.
void
worker10(void *arg1, void *arg2) {
  int tmp1 = *(int*)arg1;
  int tmp2 = *(int*)arg2;
  assert(tmp1 == 42);
  assert(tmp2 == 24);
  assert(global == 1);
  global++;
  int nested_thread_pid = thread_create(nested_worker2, &tmp1, &tmp2);
  assert(nested_thread_pid > 0);
  for(int j=0;j<10000;j++);
  int nested_join_pid = thread_join();
  // thread_join reports failure as -1, which is truthy, so check the sign
  // (as the other tests do) rather than plain truthiness.
  assert(nested_join_pid > 0);
  assert(nested_join_pid == nested_thread_pid);
  exit(0);
}
// Nested thread body for test11. It runs the same two-phase size check as
// worker3: in each phase it takes that phase's lock, compares sbrk(0) with
// the published size, and counts itself in global.
void nest_worker3(void *arg1, void *arg2)
{
lock_acquire(&lock);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock);
lock_acquire(&lock2);
assert((uint64)sbrk(0) == size);
global++;
lock_release(&lock2);
exit(0);
}
// Thread body for test11: runs test3's lock/sbrk handshake from inside a
// thread, against a single nested nest_worker3.
// Note: it sets the global num_threads to 1, and this persists after the
// test.
void worker11(void *arg1, void *arg2) {
num_threads = 1;
lock_init(&lock);
lock_init(&lock2);
lock_acquire(&lock);
lock_acquire(&lock2);
int nested_thread_id = thread_create(nest_worker3, 0, 0);
assert(nested_thread_id > 0);
size = (uint64)sbrk(0);
while (global < num_threads) {
lock_release(&lock);
sleep(2);
lock_acquire(&lock);
}
global = 0;
sbrk(10000);
size = (uint64)sbrk(0);
lock_release(&lock);
while (global < num_threads) {
lock_release(&lock2);
sleep(2);
lock_acquire(&lock2);
}
lock_release(&lock2);
int nested_join_pid = thread_join();
assert(nested_join_pid > 0);
exit(0);
}
// Recurse without bound so the thread's stack overflows into its guard
// page (test12). Sets global to 3 on every call.
// NOTE(review): this relies on the call staying a real recursive call; an
// optimizer doing sibling-call elimination could turn it into an endless
// loop. Confirm the build's optimization flags.
void call_forever()
{
int k = 3;
global = k;
call_forever();
}
// Thread body for test12: checks its arguments, then overflows its stack
// via call_forever(). The exit(0) is never reached.
void
worker12(void *arg1, void *arg2) {
int tmp1 = *(int*)arg1;
int tmp2 = *(int*)arg2;
assert(tmp1 == 11);
assert(tmp2 == 22);
assert(global == 1);
call_forever();
exit(0);
}
// Nested thread body for test13: only checks its arguments (35, 42).
void empty(void *arg1, void *arg2)
{
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
exit(0);
}
// Thread body for test13: sleeps so many threads overlap, checks its
// arguments, then creates and joins an `empty` nested thread. Together these
// run many concurrent thread_create (stack allocation) calls.
void
worker13(void *arg1, void *arg2) {
sleep(3);
int arg1_int = *(int*)arg1;
int arg2_int = *(int*)arg2;
assert(arg1_int == 35);
assert(arg2_int == 42);
sleep(3);
int nested_thread_pid = thread_create(empty, &arg1_int, &arg2_int);
int nested_join_pid = thread_join();
assert(nested_join_pid == nested_thread_pid);
exit(0);
}
测试结果
跑一下回归测试 usertests