diff --git a/lib/cachet.ml b/lib/cachet.ml index 06cb03c..89a26aa 100644 --- a/lib/cachet.ml +++ b/lib/cachet.ml @@ -37,6 +37,8 @@ let memmove src ~src_off dst ~dst_off ~len = let dst = Bigarray.Array1.sub dst dst_off len in Bigarray.Array1.blit src dst +let invalid_argf fmt = Format.kasprintf invalid_arg fmt + module Bstr = struct type t = bigstring @@ -396,7 +398,6 @@ external hash : (int32[@unboxed]) -> int -> (int32[@unboxed]) [@@noalloc] let hash h d = Int32.to_int (hash h d) -let failwithf fmt = Format.ksprintf (fun str -> failwith str) fmt type slice = { offset: int; length: int; payload: bigstring } @@ -480,7 +481,8 @@ let load t ?(len = 1) logical_address = if len > 1 lsl t.pagesize then invalid_arg "Cachet.load: you can not load more than a page"; if logical_address < 0 then - invalid_arg "Cachet.load: a logical address must be positive"; + invalid_argf "Cachet.load: a logical address must be positive (%08x)" + logical_address; let page = logical_address lsr t.pagesize in let hash = hash 0l (page lsl t.pagesize) land ((1 lsl t.cachesize) - 1) in let offset = logical_address land ((t.pagesize lsl 1) - 1) in @@ -507,12 +509,16 @@ let invalidate t ~off:logical_address ~len = let is_aligned x = x land ((1 lsl 2) - 1) == 0 +exception Out_of_bounds of int + +let[@inline never] out_of_bounds offset = raise (Out_of_bounds offset) + let get_uint8 t logical_address = match load t ~len:1 logical_address with | Some { payload; _ } -> let offset = logical_address land ((1 lsl t.pagesize) - 1) in Bstr.get_uint8 payload offset - | None -> failwithf "Cachet.get_uint8" + | None -> out_of_bounds logical_address let get_int8 t logical_address = (get_uint8 t logical_address lsl (Sys.int_size - 8)) asr (Sys.int_size - 8) @@ -523,7 +529,7 @@ let blit_to_bytes t ~src_off:logical_address buf ~dst_off ~len = let off = logical_address land ((1 lsl t.pagesize) - 1) in if is_aligned off && (1 lsl t.pagesize) - off >= len then begin match load t ~len logical_address with - | None -> failwithf "Cachet.blit_to_bytes" + | None -> out_of_bounds logical_address | Some slice -> Bstr.blit_to_bytes slice.payload ~src_off:off buf ~dst_off:0 ~len end @@ -619,13 +625,14 @@ let iter_with_len t len ~fn logical_address = let logical_address = offset + (1 lsl t.pagesize) in match load t logical_address with | Some { payload; length; _ } -> - if len - max > length then failwith "Chat.iter_with_len"; + if len - max > length then + out_of_bounds (logical_address + (len - max - 1)); for i = 0 to len - max - 1 do fn (Bstr.get_uint8 payload i) done - | None -> failwith "Chat.iter_with_len" + | None -> out_of_bounds logical_address end - | None -> failwith "Chat.iter_with_len" + | None -> out_of_bounds logical_address end let iter t ?len ~fn logical_address = diff --git a/lib/cachet.mli b/lib/cachet.mli index 764515b..e255021 100644 --- a/lib/cachet.mli +++ b/lib/cachet.mli @@ -281,13 +281,20 @@ val invalidate : 'fd t -> off:int -> len:int -> unit 8-bit or 16-bit integers represented by [int] values sign-extend (resp. zero-extend) their result. *) +exception Out_of_bounds of int +(** If Cachet tries to retrieve a byte outside the block device, this exception is raised. *) + val get_int8 : 'fd t -> int -> int (** [get_int8 t logical_address] is [t]'s signed 8-bit integer starting at byte - index [logical_address]. *) + index [logical_address]. + + @raise Out_of_bounds if [logical_address] is not accessible. *) val get_uint8 : 'fd t -> int -> int (** [get_uint8 t logical_address] is [t]'s unsigned 8-bit integer starting at - byte index [logical_address]. *) + byte index [logical_address]. + + @raise Out_of_bounds if [logical_address] is not accessible. *) val get_uint16_ne : 'fd t -> int -> int val get_uint16_le : 'fd t -> int -> int @@ -309,8 +316,8 @@ val get_string : 'fd t -> len:int -> int -> string You can use {!val:syscalls} to find out how many times [get_string] can call [map] at most. - @raise Failure if the [map] function cannot give us enough to copy [len] - bytes. *) + @raise Out_of_bounds if [logical_address] and [len] byte(s) are not + accessible. *) val get_seq : 'fd t -> int -> string Seq.t val next : 'fd t -> slice -> slice option