How to Download a File with Golang Goroutines and WaitGroup

What is a goroutine?

Go provides a feature known as goroutines. A goroutine is a function or method that runs independently of, and concurrently with, the rest of your program.

You can create as many goroutines as you need in a single program. To start one, prefix a function or method call with the go keyword, like this:

package main

import "fmt"

func main() {
	go funcName()
}

func funcName() {
	// do something
	fmt.Println("Hello World")
}
            

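But this code will most likely not print “Hello World”: `main` returns right after starting the goroutine, and a Go program exits as soon as `main` returns, without waiting for other goroutines. To make `main` wait, you can use a WaitGroup from the `sync` package.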

package main

import (
	"fmt"
	"sync"
)

var wg sync.WaitGroup

func main() {
	wg.Add(1) // one goroutine to wait for
	go funcName()
	wg.Wait() // blocks until Done has been called
}

func funcName() {
	// do something
	fmt.Println("Hello World")
	wg.Done()
}
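
The same pattern extends to several goroutines. Here is a minimal sketch, where `squares` is a hypothetical helper used only for illustration: each goroutine writes its result into its own slot of a slice, and `wg.Wait()` blocks until every goroutine has called `Done`.

```go
package main

import (
	"fmt"
	"sync"
)

// squares computes n*n for every input in its own goroutine
// and waits for all of them before returning.
func squares(inputs []int) []int {
	results := make([]int, len(inputs))
	var wg sync.WaitGroup
	for i, n := range inputs {
		wg.Add(1)
		go func(i, n int) {
			defer wg.Done()
			results[i] = n * n // each goroutine writes to its own index
		}(i, n)
	}
	wg.Wait()
	return results
}

func main() {
	fmt.Println(squares([]int{1, 2, 3})) // prints [1 4 9]
}
```

Because every goroutine touches a different index, no lock is needed here. Using `defer wg.Done()` makes sure the counter is decremented even if the function returns early.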
            

In this article, we’ll write a small program that downloads a file from its download link, using several goroutines in parallel.

So let’s go!

Get the repo for this article here.

First generate the project like this:

                go mod init file-downloader
            

You can give your project any name; I chose “file-downloader”. Open the project in your favorite IDE (VS Code, GoLand…). Now create a file and name it “main.go”.

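package main

import "log"

// worker downloads the file at url. Its body is filled in step by step below.
func worker(url string) error {
	// code here
	return nil
}

func main() {
	url := "https://agritrop.cirad.fr/584726/1/Rapport.pdf"
	if err := worker(url); err != nil {
		log.Fatal(err) // log.Fatal exits the program, no return needed
	}
}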
            

This code declares a function, worker, which receives a URL. We’ll build the rest of the program inside it.

                func worker(url string) error {
	if strings.TrimSpace(url) == "" {
		return errors.New("invalid url")
	}
	return nil
}
            

We check that the URL we received is not empty: if it is, we return an error; otherwise we continue.
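
The emptiness check only catches blank input. A stricter check could parse the URL first. The sketch below is an assumption rather than part of the original program: `validateURL` is a hypothetical helper that also rejects unsupported schemes and a missing host.

```go
package main

import (
	"errors"
	"fmt"
	"net/url"
	"strings"
)

// validateURL is a hypothetical helper: it accepts only non-empty
// http or https URLs that contain a host.
func validateURL(raw string) error {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return errors.New("invalid url: empty")
	}
	u, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("invalid url: %w", err)
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		return fmt.Errorf("invalid url: unsupported scheme %q", u.Scheme)
	}
	if u.Host == "" {
		return errors.New("invalid url: missing host")
	}
	return nil
}

func main() {
	fmt.Println(validateURL("https://agritrop.cirad.fr/584726/1/Rapport.pdf"))
	fmt.Println(validateURL("ftp://example.com/file"))
	fmt.Println(validateURL("   "))
}
```

Inside worker, such a helper could replace the `strings.TrimSpace` test, so that obviously wrong input fails before any network request is made.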

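func worker(url string) error {
	if strings.TrimSpace(url) == "" {
		return errors.New("invalid url")
	}
	// A HEAD request returns only the headers, not the file itself.
	res, err := http.Head(url)
	if err != nil {
		// keep the underlying cause instead of guessing it
		return fmt.Errorf("HEAD request failed: %w", err)
	}
	defer res.Body.Close()
	if res.Header.Get("Accept-Ranges") != "bytes" {
		return errors.New("unable to download file with multithreads")
	}
	// The file name is the last segment of the URL.
	urlSplit := strings.Split(url, "/")
	filename := urlSplit[len(urlSplit)-1]
	_ = filename // used later, when the parts are merged

	return nil
}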
            

We send a HEAD request to the URL to retrieve the response headers.

  • if the request fails (for example because the URL is malformed), return an error
  • check that the “Accept-Ranges” header is set to “bytes”; if not, return an error
  • extract the file name from the last segment of the URL
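
Splitting on “/” works for simple links, but it keeps any query string in the file name. As a hedged alternative, the sketch below parses the URL and takes the base of its path; `fileNameFromURL` is a hypothetical helper, and the `?download=1` query string is made up for the demonstration.

```go
package main

import (
	"fmt"
	"net/url"
	"path"
)

// fileNameFromURL is a hypothetical helper: it returns the last
// segment of the URL path, ignoring the query string and fragment.
func fileNameFromURL(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	name := path.Base(u.Path)
	if name == "." || name == "/" {
		return "", fmt.Errorf("no file name in url %q", raw)
	}
	return name, nil
}

func main() {
	name, err := fileNameFromURL("https://agritrop.cirad.fr/584726/1/Rapport.pdf?download=1")
	fmt.Println(name, err) // prints Rapport.pdf <nil>
}
```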

Why do we have to check this header (Accept-Ranges)?

The Accept-Ranges HTTP response header is a marker used by the server to advertise its support for partial requests from the client for file downloads. The value of this field indicates the unit that can be used to define a range.

In the presence of an Accept-Ranges header, the browser may try to resume an interrupted download instead of trying to restart the download (see more…).
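
To see the header in action without depending on a remote site, here is a small sketch that runs against a local test server. `probe` and `newTestServer` are hypothetical names; the server uses `http.ServeContent` from the standard library, which advertises and honours byte ranges.

```go
package main

import (
	"bytes"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
	"time"
)

// probe is a hypothetical helper: it sends a HEAD request and returns
// the file size when the server advertises byte-range support.
func probe(url string) (int64, error) {
	res, err := http.Head(url)
	if err != nil {
		return 0, fmt.Errorf("HEAD request failed: %w", err)
	}
	defer res.Body.Close()
	if res.Header.Get("Accept-Ranges") != "bytes" {
		return 0, errors.New("server does not advertise byte ranges")
	}
	return strconv.ParseInt(res.Header.Get("Content-Length"), 10, 64)
}

// newTestServer starts a local server (an assumption for this demo)
// that serves 1000 bytes through http.ServeContent.
func newTestServer() *httptest.Server {
	data := bytes.Repeat([]byte("x"), 1000)
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.ServeContent(w, r, "demo.bin", time.Time{}, bytes.NewReader(data))
	}))
}

func main() {
	srv := newTestServer()
	defer srv.Close()

	size, err := probe(srv.URL)
	fmt.Println(size, err)
}
```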

In our case, ranges let us split the file into several parts, each of which is handed to its own goroutine.
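
The arithmetic behind the split can be sketched in isolation. This is an illustration under stated assumptions, not the article's code: `byteRange` and `splitRanges` are hypothetical names. Both ends of an HTTP range are inclusive, so each part ends one byte before the next one starts, and the last part absorbs the remainder of the division.

```go
package main

import "fmt"

// byteRange is an inclusive range, as used by the HTTP Range header.
type byteRange struct{ Start, End int }

// splitRanges is a hypothetical helper: it cuts size bytes into
// parts contiguous, non-overlapping inclusive ranges.
func splitRanges(size, parts int) []byteRange {
	if size <= 0 || parts <= 0 {
		return nil
	}
	if parts > size {
		parts = size // never create empty parts
	}
	chunk := size / parts
	ranges := make([]byteRange, 0, parts)
	for i := 0; i < parts; i++ {
		start := i * chunk
		end := start + chunk - 1
		if i == parts-1 {
			end = size - 1 // the last part takes the remainder
		}
		ranges = append(ranges, byteRange{start, end})
	}
	return ranges
}

func main() {
	for _, r := range splitRanges(10, 3) {
		fmt.Printf("bytes=%d-%d\n", r.Start, r.End)
	}
	// prints:
	// bytes=0-2
	// bytes=3-5
	// bytes=6-9
}
```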

func worker(url string) error {
	if strings.TrimSpace(url) == "" {
		return errors.New("invalid url")
	}
	client := http.Client{}

	res, err := http.Head(url)
	if err != nil {
		return fmt.Errorf("HEAD request failed: %w", err)
	}
	defer res.Body.Close()
	if res.Header.Get("Accept-Ranges") != "bytes" {
		return errors.New("unable to download file with multithreads")
	}
	urlSplit := strings.Split(url, "/")
	filename := urlSplit[len(urlSplit)-1]
	_ = filename // used later, when the parts are merged

	// size of file
	cntLen, err := strconv.Atoi(res.Header.Get("Content-Length"))
	if err != nil {
		return err
	}
	nbPart := 3
	offset := cntLen / nbPart

	for i := 0; i < nbPart; i++ {
		name := fmt.Sprintf("part%d", i)
		start := i * offset
		// Range ends are inclusive: stop one byte before the next part starts.
		end := start + offset - 1
		if i == nbPart-1 {
			end = cntLen - 1 // the last part also takes the remainder
		}

		err := func() error {
			part, err := os.Create(name)
			if err != nil {
				return err
			}
			defer part.Close()

			req, err := http.NewRequest(http.MethodGet, url, nil)
			if err != nil {
				return err
			}

			// Add range with start and end
			req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
			res, err := client.Do(req)
			if err != nil {
				return err
			}
			defer res.Body.Close()
			// 206 means the server honoured the Range header.
			if res.StatusCode != http.StatusPartialContent {
				return fmt.Errorf("part %d: unexpected status %s", i, res.Status)
			}

			// Stream the response body straight into the part file.
			_, err = io.Copy(part, res.Body)
			return err
		}()
		if err != nil {
			return err
		}
	}

	return nil
}
            

In this part, we declare a variable, nbPart, which holds the number of parts we want to split the file into; it is also the number of goroutines we will start. For the moment we have not added the go keyword, so the parts are downloaded one after another rather than in goroutines. The idea is to:

  • Create a file (one part of the final file)
  • Send a request that downloads only that part’s byte range
  • Write the response into the part file

We repeat this three times in the loop because we set nbPart = 3.

Add the go keyword and a WaitGroup

func worker(url string) error {
	if strings.TrimSpace(url) == "" {
		return errors.New("invalid url")
	}
	client := http.Client{}

	res, err := http.Head(url)
	if err != nil {
		return fmt.Errorf("HEAD request failed: %w", err)
	}
	defer res.Body.Close()
	urlSplit := strings.Split(url, "/")
	filename := urlSplit[len(urlSplit)-1]
	if res.Header.Get("Accept-Ranges") != "bytes" {
		return errors.New("unable to download file with multithreads")
	}

	cntLen, err := strconv.Atoi(res.Header.Get("Content-Length"))
	if err != nil {
		return err
	}
	nbPart := 3
	offset := cntLen / nbPart

	// waitgroup for goroutine
	wg := sync.WaitGroup{}
	// one error slot per goroutine, so no lock is needed
	errs := make([]error, nbPart)

	for i := 0; i < nbPart; i++ {
		wg.Add(1)
		name := fmt.Sprintf("part%d", i)
		start := i * offset
		// Range ends are inclusive: stop one byte before the next part starts.
		end := start + offset - 1
		if i == nbPart-1 {
			end = cntLen - 1 // the last part also takes the remainder
		}

		i := i // per-iteration copy for the closure (required before Go 1.22)

		// add keyword go to run the function in goroutine.
		go func() {
			defer wg.Done()
			part, err := os.Create(name)
			if err != nil {
				errs[i] = err
				return
			}
			defer part.Close()

			req, err := http.NewRequest(http.MethodGet, url, nil)
			if err != nil {
				errs[i] = err
				return
			}

			req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
			res, err := client.Do(req)
			if err != nil {
				errs[i] = err
				return
			}
			defer res.Body.Close()
			if res.StatusCode != http.StatusPartialContent {
				errs[i] = fmt.Errorf("part %d: unexpected status %s", i, res.Status)
				return
			}

			// Stream the response body straight into the part file.
			_, errs[i] = io.Copy(part, res.Body)
		}()
	}

	wg.Wait()
	// Report the first download error, if any.
	for _, err := range errs {
		if err != nil {
			return err
		}
	}

	out, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer out.Close()
	for i := 0; i < nbPart; i++ {
		name := fmt.Sprintf("part%d", i)
		file, err := os.ReadFile(name)
		if err != nil {
			return err
		}
		// Each part is written at the offset where it starts in the final file.
		if _, err := out.WriteAt(file, int64(i*offset)); err != nil {
			return err
		}

		if err := os.Remove(name); err != nil {
			return err
		}
	}

	return nil
}
            

After the for loop, we wait for all goroutines to finish, then create a new file that aggregates all the parts, and we’re done.

Bonus

You can add a progress bar to visualize the download. Install this package:

                go get "github.com/schollz/progressbar/v3"
            
func worker(url string) error {
	if strings.TrimSpace(url) == "" {
		return errors.New("invalid url")
	}
	client := http.Client{}

	res, err := http.Head(url)
	if err != nil {
		return fmt.Errorf("HEAD request failed: %w", err)
	}
	defer res.Body.Close()
	urlSplit := strings.Split(url, "/")
	filename := urlSplit[len(urlSplit)-1]
	if res.Header.Get("Accept-Ranges") != "bytes" {
		return errors.New("unable to download file with multithreads")
	}

	cntLen, err := strconv.Atoi(res.Header.Get("Content-Length"))
	if err != nil {
		return err
	}
	nbPart := 3
	offset := cntLen / nbPart

	// waitgroup for goroutine
	wg := sync.WaitGroup{}
	// one error slot per goroutine, so no lock is needed
	errs := make([]error, nbPart)

	for i := 0; i < nbPart; i++ {
		wg.Add(1)
		name := fmt.Sprintf("part%d", i)
		start := i * offset
		// Range ends are inclusive: stop one byte before the next part starts.
		end := start + offset - 1
		if i == nbPart-1 {
			end = cntLen - 1 // the last part also takes the remainder
		}

		i := i // per-iteration copy for the closure (required before Go 1.22)

		// add keyword go to run the function in goroutine.
		go func() {
			defer wg.Done()
			part, err := os.Create(name)
			if err != nil {
				errs[i] = err
				return
			}
			defer part.Close()

			req, err := http.NewRequest(http.MethodGet, url, nil)
			if err != nil {
				errs[i] = err
				return
			}

			req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
			res, err := client.Do(req)
			if err != nil {
				errs[i] = err
				return
			}
			defer res.Body.Close()
			if res.StatusCode != http.StatusPartialContent {
				errs[i] = fmt.Errorf("part %d: unexpected status %s", i, res.Status)
				return
			}

			// progress bar to visualize an evolution of download
			bar := progressbar.DefaultBytes(
				res.ContentLength,
				fmt.Sprintf("downloading-worker %d", i+1),
			)
			// The body can only be read once, so a single io.Copy feeds
			// both the part file and the progress bar.
			_, errs[i] = io.Copy(io.MultiWriter(part, bar), res.Body)
			// end progress bar
		}()
	}

	wg.Wait()
	// Report the first download error, if any.
	for _, err := range errs {
		if err != nil {
			return err
		}
	}

	out, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer out.Close()
	for i := 0; i < nbPart; i++ {
		name := fmt.Sprintf("part%d", i)
		file, err := os.ReadFile(name)
		if err != nil {
			return err
		}
		// Each part is written at the offset where it starts in the final file.
		if _, err := out.WriteAt(file, int64(i*offset)); err != nil {
			return err
		}

		if err := os.Remove(name); err != nil {
			return err
		}
	}

	return nil
}
            

Don’t forget to import all the packages you use. To run the project:

                go run main.go
            

Hope you all enjoy it!

Be lenient on my English, I’m French guy. Thanks !!!
