-- DNS over HTTP(S) for Knot Resolver
--[[
-- config
modules.load('http')
doh = require 'doh'
http.add_interface {
	host = '192.0.2.1', port = 443,
	cert = 'fullchain.pem', key = 'privkey.pem',
	endpoints = {
		['/dns-resolve'] = { 'application/json', doh.serve_doh }
	},
}
]]--

local ffi = require('ffi')
-- NOTE(review): `knot` is not referenced anywhere in this module; the load is kept as-is.
local knot = ffi.load(libknot_SONAME)
local condition = require('cqueues.condition')
local http_util = require('http.util')
local headers = require('http.headers')
local basexx = require('basexx')
local bit = require('bit')

--- Module table and feature switches.
-- wireformat: accept DNS wire-format queries (POST body, GET `?dns=`)
-- json:       accept JSON-style GET queries (`?name=&type=&cd=&do=`)
-- logging:    log every request through kresd's `log`
local M = {
	wireformat = true,
	json = true,
	logging = false,
}

--- Resolve `query` synchronously, suspending the current cqueues coroutine
-- until the resolver's finish callback fires.
-- @tparam table query parsed query: `name`, `class`, `type` and optional
--   boolean `do` / `cd` flags
-- @return a Lua-owned, parsed copy of the answer packet, or false when no
--   answer could be captured or the captured wire could not be parsed
local function resolve_s(query)
	local wire
	local cond = condition.new()
	local waiting, done = false, false
	local resolve_cb = function (req)
		-- NOTE(review): the answer packet presumably lives in the request's
		-- memory and may be released once this callback returns, so copy the
		-- wire bytes out here instead of keeping the raw pointer.
		local answer = kres.pkt_t(kres.request_t(req).answer)
		if answer ~= nil then
			wire = ffi.string(answer.wire, answer.size)
		end
		if waiting then
			cond:signal()
		end
		done = true
	end
	local finish_cb = ffi.cast('trace_callback_f', resolve_cb)

	-- Build the option list sequentially so it never contains nil holes
	-- (a hole would make `#`/ipairs consumers drop later flags).
	local options = {}
	if query["do"] then
		options[#options + 1] = "DNSSEC_WANT"
	end
	if query["cd"] then
		options[#options + 1] = "DNSSEC_CD"
	end

	resolve {
		name = query.name,
		class = query.class,
		type = query.type,
		options = options,
		init = function (req)
			req = kres.request_t(req)
			req.trace_finish = finish_cb
		end
	}
	-- The callback may already have run synchronously inside resolve().
	if not done then
		waiting = true
		cond:wait()
	end
	finish_cb:free()

	if not wire then
		return false
	end
	-- NOTE(review): assumes kres.packet() copies `wire` into its own buffer
	-- (the caller later rewrites the message id via pkt:id()) — confirm.
	local pkt = kres.packet(#wire, wire)
	assert(ffi.istype(ffi.typeof("knot_pkt_t"), pkt), "pkt: assertion failed")
	if not pkt:parse() then
		return false
	end
	return pkt
end

--- Parse a DNS wire-format query.
-- @param pkt raw message bytes; anything other than a string (e.g. the nil
--   returned by a failed base64 decode or body read) is rejected
-- @return query table, or false plus an error string
local function parse_query_pkt(pkt)
	-- Reject non-strings, messages shorter than a DNS header (12 bytes) and
	-- messages longer than 65535 bytes.
	if type(pkt) ~= "string" or #pkt < 12 or #pkt > 65535 then
		return false, "parse failed"
	end
	local qpkt = kres.packet(#pkt, pkt)
	assert(ffi.istype(ffi.typeof("knot_pkt_t"), qpkt), "qpkt: assertion failed")
	if not qpkt:parse() then
		return false, "parse failed"
	end
	if qpkt:qdcount() ~= 1 then
		return false, "invalid qdcount"
	end
	local query = {
		id = qpkt:id(),
		name = kres.dname2str(qpkt:qname()),
		class = qpkt:qclass(),
		type = qpkt:qtype(),
		rd = qpkt:rd(),
		cd = qpkt:cd(),
		raw = true,
	}
	-- edns0: the DO flag is bit 0x8000 of the OPT record's TTL field
	if qpkt.opt_rr ~= nil then
		query.opt = qpkt.opt_rr
		query.buffsize = query.opt:class()
		query.exflags = query.opt:ttl()
		query["do"] = bit.band(query.exflags, 0x8000) ~= 0
	end
	return query
end

--- Build a query from JSON-style GET arguments.
-- @tparam table qarg decoded query-string arguments (`name` is required by
--   the caller; `type`, `cd`, `do`, `encoding` are optional)
-- @return query table, or false plus an error string for an unknown type
local function parse_query_string(qarg)
	local query = {
		id = 0,
		name = qarg.name,
		class = kres.class.IN,
		raw = qarg.encoding == "raw"
	}
	if not qarg.type or #qarg.type == 0 then
		query.type = kres.type.A
	else
		-- numeric type in 1..65535, otherwise a mnemonic such as "AAAA"
		local n = tonumber(qarg.type)
		if n and n > 0 and n < 65536 and n == math.floor(n) then
			query.type = n
		else
			query.type = kres.type[qarg.type:upper()]
			if not query.type then
				return false, "invalid query type"
			end
		end
	end
	for _, flag in ipairs{"cd", "do"} do
		query[flag] = qarg[flag] == "1" or qarg[flag] == "true"
	end
	return query
end

--- Convert one packet section into a JSON-ready list, skipping OPT and TSIG.
-- @return list of {name, type, TTL, data}, or nil when nothing remains
local function section2json(rrsets)
	local r = {}
	for i = 1, #rrsets do
		local rr = rrsets[i]
		if rr.type ~= kres.type.OPT and rr.type ~= kres.type.TSIG then
			r[#r+1] = {
				name = kres.dname2str(rr.owner),
				type = rr.type,
				TTL = rr.ttl,
				-- keep only the rdata part of the presentation format
				data = kres.rr2str(rr):gsub("[^\t]*\t%d+\t%w+\t", "")
			}
		end
	end
	return #r > 0 and r or nil
end

--- Serialize an answer packet to a JSON string.
local function pkt2json(pkt)
	local result = {
		Status = pkt:rcode(),
		RD = pkt:rd(),
		TC = pkt:tc(),
		AA = pkt:aa(),
		QR = pkt:qr(),
		CD = pkt:cd(),
		AD = pkt:ad(),
		RA = pkt:ra(),
		Question = {
			{
				name = kres.dname2str(pkt:qname()),
				type = pkt:qtype(),
			},
		},
		Answer = section2json(pkt:section(kres.section.ANSWER)),
		Authority = section2json(pkt:section(kres.section.AUTHORITY)),
		Additional = section2json(pkt:section(kres.section.ADDITIONAL)),
	}
	return tojson(result)
end

--- Decode the GET `dns` parameter.
-- RFC 8484 uses base64url (no padding); it differs from standard base64 only
-- in '-' and '_', so mapping those back keeps standard base64 working too.
-- @return decoded bytes, or nil (plus the offending character) on bad input
local function decode_dns_param(s)
	return basexx.from_base64((s:gsub("%-", "+"):gsub("_", "/")))
end

--- Compute the Cache-Control max-age for an answer packet.
-- Uses the smallest TTL in the ANSWER section, capped at cache.max_ttl().
-- With no answer records, it falls back to the first AUTHORITY SOA's MINIMUM
-- field, capped by that SOA's TTL and cache.max_ttl().
-- @return max-age in seconds, or nil when neither source applies
local function answer_maxage(pkt)
	local maxttl = cache.max_ttl()
	local maxage
	local answer = pkt:section(kres.section.ANSWER)
	for i = 1, #answer do
		maxage = math.min(maxage or maxttl, answer[i].ttl)
	end
	if maxage then
		return maxage
	end
	local authority = pkt:section(kres.section.AUTHORITY)
	for i = 1, #authority do
		local rr = authority[i]
		if rr.type == kres.type.SOA then
			-- MINIMUM is the last 32-bit field of the SOA rdata
			local minimum = tonumber(basexx.to_bit(rr.rdata:sub(-4)), 2)
			return math.min(minimum, rr.ttl, maxttl)
		end
	end
	return nil
end

-- Wire-format media types in order of preference; the first one found in the
-- Accept header is echoed back, otherwise the first entry is used.
local WIRE_CONTENT_TYPES = {
	"application/dns-message",       -- draft-ietf-doh 07-
	"message/dns",                   -- draft-ietf-doh 06
	"application/dns-udpwireformat", -- draft-ietf-doh -05
}

--- Pick the content type for a wire-format response from the Accept header.
local function wire_content_type(accept)
	accept = accept or ""
	for _, ct in ipairs(WIRE_CONTENT_TYPES) do
		-- plain find: '-' is a magic character in Lua patterns
		if accept:find(ct, 1, true) then
			return ct
		end
	end
	return WIRE_CONTENT_TYPES[1]
end

--- HTTP endpoint handler for DNS-over-HTTP(S).
-- Accepts wire-format queries (POST body or GET `?dns=`) and JSON-style GET
-- queries (`?name=`), resolves them and writes the response on `stream`.
-- @param h request headers (lua-http headers object)
-- @param stream lua-http server stream
-- @return status code and reason for error/HEAD replies; false after the
--   response has been written here (disables the default action)
function M.serve_doh(h, stream)
	local method = h:get(':method')
	local query, err
	local arg = {}
	local _, addr, port = stream.connection:peername()
	addr = port and addr .. "@" .. port or "[unknown]"

	-- parse query
	if method == "POST" then
		if not M.wireformat then
			return 405, "method not allowed"
		end
		query, err = parse_query_pkt(stream:get_body_as_string())
	elseif method == "GET" then
		local qs = h:get(':path'):match('[^?]+%?(.*)') or ""
		for k, v in http_util.query_args(qs) do
			arg[k] = v
		end
		if M.wireformat and arg.dns then
			query, err = parse_query_pkt(decode_dns_param(arg.dns))
		elseif M.json and arg.name then
			query, err = parse_query_string(arg)
		else
			return 400, "bad request"
		end
	elseif method == "HEAD" then
		return 200
	else
		return 405, "method not allowed"
	end

	if not query then
		if M.logging then
			log("[doh] %s %s", addr, err)
		end
		return 400, "bad request"
	end

	if M.logging then
		if M.json and arg.name then
			log("[doh] %s %q", addr, h:get(":path"))
		else
			log("[doh] %s %s/%s/%s %s%s%s%s",
				addr, query.name,
				kres.tostring.class[query.class],
				kres.tostring.type[query.type],
				query.rd and "+" or "-",
				query.opt and "E" or "",
				query["do"] and "D" or "",
				query.cd and "C" or "")
		end
	end

	-- resolve
	local pkt = resolve_s(query)
	if not pkt then
		return 504, "timeout"
	end

	local maxage = answer_maxage(pkt)

	-- build response body
	local body, content_type
	if query.raw then
		pkt:id(query.id) -- fix id
		body = pkt:towire()
		content_type = wire_content_type(h:get('accept'))
	else
		body = pkt2json(pkt)
		content_type = "application/json"
		-- alternatives seen elsewhere:
		-- "application/x-javascript" (google), "application/dns-json" (cloudflare)
	end

	-- response header
	local hsend = headers.new()
	hsend:append(":status", "200")
	hsend:append("content-length", tostring(#body))
	if maxage then
		hsend:append("cache-control", "max-age=" .. tostring(maxage))
	end
	hsend:append("content-type", content_type)
	assert(stream:write_headers(hsend, false))

	-- response body
	assert(stream:write_chunk(body, true))

	-- disable default action
	return false
end

return M