本系列整理了10个工作量和难度适中的Golang小项目,适合已经掌握Go语法的工程师进一步熟练语法和常用库的用法。
问题描述:
实现一个网络爬虫,以输入的URL为起点,使用广度优先顺序访问页面。
要点:
实现对多个页面的并发访问,同时访问的页面数由参数 -concurrency 指定,默认为 20。
使用 -depth 指定访问的页面深度,默认为 3。
注意已经访问过的页面不要重复访问。
扩展:
将访问到的页面写入本地,以创建目标网站的本地镜像。注意,只需采集指定域名下的页面;写入本地的页面中,<a> 元素的 href 值需要改写为指向镜像页面,而不是原始页面。
实现
import (
"bytes"
"flag"
"fmt"
"golang.org/x/net/html"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// URLInfo pairs a URL with the crawl depth at which it was discovered.
type URLInfo struct {
	url   string // page address to fetch
	depth int    // crawl level at which url was found
}
// base is the start URL given with -u; main sets it before crawling.
// Extract resolves its argument against base and refuses to save pages whose
// host differs from base.Host.
var base *url.URL
// forEachNode walks the HTML tree rooted at n depth-first. For every node it
// calls pre (when non-nil) before descending into the node's children and
// post (when non-nil) after all children have been visited.
func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
	var walk func(node *html.Node)
	walk = func(node *html.Node) {
		if pre != nil {
			pre(node)
		}
		for child := node.FirstChild; child != nil; child = child.NextSibling {
			walk(child)
		}
		if post != nil {
			post(node)
		}
	}
	walk(n)
}
// linkNodes collects every <a> element node in the tree rooted at n, in
// document (pre-order) order.
func linkNodes(n *html.Node) []*html.Node {
	var anchors []*html.Node
	forEachNode(n, func(node *html.Node) {
		if node.Type != html.ElementNode || node.Data != "a" {
			return
		}
		anchors = append(anchors, node)
	}, nil)
	return anchors
}
// linkURLs resolves the href of each <a> node in linkNodes against base and
// returns the absolute http(s) URLs that point at base's host.
//
// The fragment is removed from every returned URL, so "page#a" and "page#b"
// map to the same string and the caller's seen-set does not fetch the page
// twice. Unparsable hrefs are logged and skipped. Off-host links and
// non-web schemes (javascript:, mailto:, ...) are skipped silently.
func linkURLs(linkNodes []*html.Node, base *url.URL) []string {
	var urls []string
	for _, n := range linkNodes {
		for _, a := range n.Attr {
			if a.Key != "href" {
				continue
			}
			link, err := base.Parse(a.Val)
			if err != nil {
				log.Printf("skipping %q: %s", a.Val, err)
				continue
			}
			// Only follow fetchable pages on the same host. The scheme check
			// also covers "javascript:" links, which the old prefix test targeted.
			if link.Host != base.Host || (link.Scheme != "http" && link.Scheme != "https") {
				continue
			}
			// The fragment never changes which document the server returns.
			link.Fragment = ""
			urls = append(urls, link.String())
		}
	}
	return urls
}
// rewriteLocalLinks rewrites, in place, the href of every <a> node in
// linkNodes that resolves (against base, the URL of the page being saved) to
// a URL on base's host.
//
// The new href is a relative path from this page's mirror file to the
// target's mirror file, e.g. "about/index.html" or "../img/x.png". Links
// therefore keep working when the mirror is opened from disk; root-relative
// "/about" would resolve to the filesystem root. Fragments are kept, and a
// same-page anchor stays a bare "#frag". Query strings are dropped because
// the mirror stores one file per path. Unparsable and off-host hrefs are left
// unchanged.
func rewriteLocalLinks(linkNodes []*html.Node, base *url.URL) {
	pageFile := mirrorFilePath(base)
	pageDir := filepath.Dir(pageFile)
	for _, n := range linkNodes {
		for i, a := range n.Attr {
			if a.Key != "href" {
				continue
			}
			link, err := base.Parse(a.Val)
			if err != nil || link.Host != base.Host {
				continue // leave bad and non-local URLs pointing at the original site
			}
			target := mirrorFilePath(link)
			ref := url.URL{Fragment: link.Fragment}
			if target != pageFile || link.Fragment == "" {
				rel, err := filepath.Rel(pageDir, target)
				if err != nil {
					continue
				}
				ref.Path = filepath.ToSlash(rel)
			}
			// URL.String escapes the path/fragment and prefixes "./" when the
			// first path segment contains a colon, so it is not read as a scheme.
			n.Attr[i].Val = ref.String()
		}
	}
}

// mirrorFilePath maps a URL to the relative file path used for it in the
// local mirror. The result is <host>/<path>, or <host>/<path>/index.html
// when the URL path has no file extension. This reproduces the file naming
// in save; keep the two in sync. The path is cleaned as a rooted path so ".."
// segments cannot climb out of <host>.
func mirrorFilePath(u *url.URL) string {
	p := filepath.Clean("/" + u.Path)
	if filepath.Ext(u.Path) == "" {
		p = filepath.Join(p, "index.html")
	}
	return filepath.Join(u.Host, p)
}
// Extract fetches url, with a 10-second timeout, and writes the response into
// the local mirror via save.
//
// For text/html responses it returns the same-host links found in the page.
// The page's same-host <a href> values are rewritten to point into the mirror
// before the page is saved. Other content types are saved verbatim and yield
// no links.
//
// url is resolved against base. URLs on another host are logged and skipped
// before any request is made. Pages that redirect to another host are also
// skipped. A non-200 status or an HTML parse failure is returned as an error.
// A save failure is returned together with the links, so the caller can
// still follow them.
func Extract(url string) (urls []string, err error) {
	u, err := base.Parse(url)
	if err != nil {
		return nil, err
	}
	if u.Host != base.Host {
		log.Printf("not saving %s: non-local", url)
		return nil, nil
	}

	client := http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	// Close on every path (error status, redirect off-site, non-HTML body
	// copied by save) so the connection is always released.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
	}

	// save names the file after the final (post-redirect) URL, so links are
	// resolved and rewritten relative to that same URL.
	page := resp.Request.URL
	if page.Host != base.Host {
		log.Printf("not saving %s: non-local", page)
		return nil, nil
	}

	var body io.Reader // stays a nil interface for non-HTML: save copies resp.Body
	contentType := resp.Header["Content-Type"]
	if strings.Contains(strings.Join(contentType, ","), "text/html") {
		doc, err := html.Parse(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("parsing %s as HTML: %w", page, err)
		}
		nodes := linkNodes(doc)
		urls = linkURLs(nodes, page)
		rewriteLocalLinks(nodes, page)
		b := &bytes.Buffer{}
		if err := html.Render(b, doc); err != nil {
			// Best effort: still save whatever was rendered.
			log.Printf("render %s: %s", page, err)
		}
		body = b
	}
	return urls, save(resp, body)
}
// crawl is an error-tolerant wrapper around Extract. It logs any error
// instead of returning it and hands back whatever links Extract produced,
// which may be nil.
func crawl(url string) (links []string) {
	var err error
	if links, err = Extract(url); err != nil {
		log.Print(err)
	}
	return links
}
// save writes a fetched page into the local mirror under the current
// directory. The file is <host>/<path>, or <host>/<path>/index.html when the
// URL path has no extension, using resp.Request.URL. The path is cleaned as
// a rooted path first, so ".." segments in the URL cannot place the file
// outside <host>/.
//
// If body is non-nil its contents are written (e.g. re-rendered HTML);
// otherwise resp.Body is copied verbatim. save does not close resp.Body.
// Errors from creating the directory or file, copying, or closing the file
// are returned.
func save(resp *http.Response, body io.Reader) error {
	u := resp.Request.URL
	cleaned := filepath.Clean("/" + u.Path)
	filename := filepath.Join(u.Host, cleaned)
	if filepath.Ext(u.Path) == "" {
		filename = filepath.Join(u.Host, cleaned, "index.html")
	}
	if err := os.MkdirAll(filepath.Dir(filename), 0777); err != nil {
		return err
	}
	fmt.Println("filename:", filename)
	file, err := os.Create(filename)
	if err != nil {
		return err
	}

	var src io.Reader = resp.Body
	if body != nil {
		src = body
	}
	_, copyErr := io.Copy(file, src)
	closeErr := file.Close() // always close, even when the copy failed
	if copyErr != nil {
		return fmt.Errorf("save %s: %w", filename, copyErr)
	}
	if closeErr != nil {
		return fmt.Errorf("save %s: %w", filename, closeErr)
	}
	return nil
}
func parallellyCrawl(initialLinks string, concurrency, depth int){
worklist := make(chan []URLInfo, 1)
unseenLinks := make(chan URLInfo, 1)
//值为1时表示进入unseenLinks队列,值为2时表示crawl完成
seen := make(map[string] int)
seenLock := sync.Mutex{}
var urlInfos []URLInfo
for _, url := range strings.Split(initialLinks, " "){
urlInfos = append(urlInfos, URLInfo{url, 1})
}
go func() {worklist <- urlInfos}()
go func() {
for{
time.Sleep(1 * time.Second)
seenFlag := true
seenLock.Lock()
for k := range seen{
if seen[k] == 1{
seenFlag = false
}
}
seenLock.Unlock()
if seenFlag && len(worklist) == 0{
close(unseenLinks)
close(worklist)
break
}
}
}()
for i := 0; i < concurrency; i++{
go func() {
for link := range unseenLinks{
foundLinks := crawl(link.url)
var urlInfos []URLInfo
for _, u := range foundLinks{
urlInfos = append(urlInfos, URLInfo{u, link.depth + 1})
}
go func(finishedUrl string) {
worklist <- urlInfos
seenLock.Lock()
seen[finishedUrl] = 2
seenLock.Unlock()
}(link.url)
}
}()
}
for list := range worklist{
for _, link := range list {
if link.depth > depth{
continue
}
seenLock.Lock()
_, ok := seen[link.url]
seenLock.Unlock()
if !ok{
seenLock.Lock()
seen[link.url] = 1
seenLock.Unlock()
unseenLinks <- link
}
}
}
fmt.Printf("共访问了%d个页面", len(seen))
}
// main parses the command-line flags and starts the crawl.
//
//	-u                 start URL (required, absolute); base is set from it,
//	                   and Extract only saves pages on its host
//	-d, -depth         maximum crawl depth (default 3)
//	-c, -concurrency   number of crawl goroutines (default 20)
func main() {
	var maxDepth int
	var concurrency int
	var initialLink string
	// The long names match the documented options; the short ones keep
	// existing command lines working. Both names share one variable.
	flag.IntVar(&maxDepth, "d", 3, "max crawl depth")
	flag.IntVar(&maxDepth, "depth", 3, "max crawl depth (alias of -d)")
	flag.IntVar(&concurrency, "c", 20, "number of crawl goroutines")
	flag.IntVar(&concurrency, "concurrency", 20, "number of crawl goroutines (alias of -c)")
	flag.StringVar(&initialLink, "u", "", "initial link")
	flag.Parse()

	u, err := url.Parse(initialLink)
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid url: %s\n", err)
		os.Exit(1)
	}
	// url.Parse accepts "" and scheme-less input such as "example.com"
	// (which parses with an empty Host); neither can be fetched.
	if u.Scheme == "" || u.Host == "" {
		fmt.Fprintf(os.Stderr, "invalid url %q: an absolute URL such as http://example.com/ is required\n", initialLink)
		os.Exit(1)
	}
	if concurrency < 1 {
		fmt.Fprintf(os.Stderr, "invalid concurrency %d: must be at least 1\n", concurrency)
		os.Exit(1)
	}
	base = u
	parallellyCrawl(initialLink, concurrency, maxDepth)
}